diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7803dd0af..6b26d1bbb 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -123,7 +123,10 @@ jobs: defaults: run: working-directory: runner - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] steps: - uses: actions/checkout@v4 - name: Set up Go @@ -237,8 +240,8 @@ jobs: run: uv sync - name: Generate json schema run: | - uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json - uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json + uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json())" > configuration.json + uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json())" > profiles.json - name: Upload json schema to S3 run: | VERSION=$((${{ github.run_number }} + ${{ env.BUILD_INCREMENT }})) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 391761e59..55a6c9243 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -311,8 +311,8 @@ jobs: run: uv sync - name: Generate json schema run: | - uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json - uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json + uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json())" > configuration.json + uv run python -c "from dstack._internal.core.models.profiles 
import ProfilesConfig; print(ProfilesConfig.schema_json())" > profiles.json - name: Upload json schema to S3 run: | VERSION=${GITHUB_REF#refs/tags/} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d9e8408e2..4984bb121 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: rev: v1.62.0 # Should match .github/workflows/build.yml hooks: - id: golangci-lint-full - language_version: 1.23.0 # Should match runner/go.mod + language_version: 1.23.8 # Should match runner/go.mod entry: bash -c 'cd runner && golangci-lint run' stages: [manual] - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/README.md b/README.md index e48392269..766c35c2b 100644 --- a/README.md +++ b/README.md @@ -21,12 +21,13 @@ `dstack` supports `NVIDIA`, `AMD`, `Google TPU`, `Intel Gaudi`, and `Tenstorrent` accelerators out of the box. ## Latest news ✨ - +- [2025/07] [dstack 0.19.17: Secrets, Files, Rolling deployment](https://github.com/dstackai/dstack/releases/tag/0.19.17) +- [2025/06] [dstack 0.19.16: Docker in Docker, CloudRift](https://github.com/dstackai/dstack/releases/tag/0.19.16) +- [2025/06] [dstack 0.19.13: InfiniBand support in default images](https://github.com/dstackai/dstack/releases/tag/0.19.13) +- [2025/06] [dstack 0.19.12: Simplified use of MPI](https://github.com/dstackai/dstack/releases/tag/0.19.12) +- [2025/05] [dstack 0.19.10: Priorities](https://github.com/dstackai/dstack/releases/tag/0.19.10) - [2025/05] [dstack 0.19.8: Nebius clusters, GH200 on Lambda](https://github.com/dstackai/dstack/releases/tag/0.19.8) - [2025/04] [dstack 0.19.6: Tenstorrent, Plugins](https://github.com/dstackai/dstack/releases/tag/0.19.6) -- [2025/04] [dstack 0.19.5: GCP A3 High clusters](https://github.com/dstackai/dstack/releases/tag/0.19.5) -- [2025/04] [dstack 0.19.3: GCP A3 Mega clusters](https://github.com/dstackai/dstack/releases/tag/0.19.3) -- [2025/03] [dstack 0.19.0: 
Prometheus](https://github.com/dstackai/dstack/releases/tag/0.19.0) ## How does it work? diff --git a/docs/assets/stylesheets/extra.css b/docs/assets/stylesheets/extra.css index 82e80e658..45d69f906 100644 --- a/docs/assets/stylesheets/extra.css +++ b/docs/assets/stylesheets/extra.css @@ -1318,7 +1318,7 @@ html .md-footer-meta.md-typeset a:is(:focus,:hover) { display: none; } - .md-tabs__item:nth-child(6) { + .md-tabs__item:nth-child(7) { display: none; } @@ -1694,7 +1694,7 @@ a.md-go-to-action.secondary { background: white; } -.md-post__content h2 a { +.md-post__content :is(h2, h3, h4, h5, h6) a { color: rgba(0,0,0,0.87); } diff --git a/docs/blog/posts/amd-on-runpod.md b/docs/blog/posts/amd-on-runpod.md index d9d948cf6..1e32a27e7 100644 --- a/docs/blog/posts/amd-on-runpod.md +++ b/docs/blog/posts/amd-on-runpod.md @@ -4,7 +4,7 @@ date: 2024-08-21 description: "dstack, the open-source AI container orchestration platform, adds support for AMD accelerators, with RunPod as the first supported cloud provider." slug: amd-on-runpod categories: - - Releases + - Changelog --- # Supporting AMD accelerators on RunPod diff --git a/docs/blog/posts/benchmark-amd-containers-and-partitions.md b/docs/blog/posts/benchmark-amd-containers-and-partitions.md new file mode 100644 index 000000000..4df51c126 --- /dev/null +++ b/docs/blog/posts/benchmark-amd-containers-and-partitions.md @@ -0,0 +1,491 @@ +--- +title: "Benchmarking AMD GPUs: bare-metal, containers, partitions" +date: 2025-07-15 +description: "TBA" +slug: benchmark-amd-containers-and-partitions +image: https://dstack.ai/static-assets/static-assets/images/benchmark-amd-containers-and-partitions.png +categories: + - Benchmarks +--- + +# Benchmarking AMD GPUs: bare-metal, containers, partitions + +Our new benchmark explores two important areas for optimizing AI workloads on AMD GPUs: First, do containers introduce a performance penalty for network-intensive tasks compared to a bare-metal setup? 
Second, how does partitioning a powerful GPU like the MI300X affect its real-world performance for different types of AI workloads? + + + +This benchmark was supported by [Hot Aisle :material-arrow-top-right-thin:{ .external }](https://hotaisle.xyz/){:target="_blank"}, +a provider of AMD GPU bare-metal and VM infrastructure. + + + +## Benchmark 1: Bare-metal vs containers + +### Finding 1: No loss in interconnect bandwidth + +A common concern is that the abstraction layer of containers might slow down communication between GPUs on different nodes. To test this, we measured interconnect performance using two critical methods: high-level RCCL collectives (AllGather, AllReduce) essential for distributed AI, and low-level RDMA write tests for a raw measure of network bandwidth. + +#### AllGather + +The `all_gather` operation is crucial for tasks like tensor-parallel inference, where results from multiple GPUs must be combined. Our tests showed that container performance almost perfectly matched bare-metal across message sizes from 8MB to 16GB. + + + +#### AllReduce + +Similarly, `all_reduce` is the backbone of distributed training, used for synchronizing gradients. Once again, the results were clear: containers performed just as well as bare-metal. + + + +Both bare-metal and container setups achieved nearly identical peak bus bandwidth (around 350 GB/s for 16GB messages), confirming that containerization does not hinder this fundamental collective operation. + +??? info "Variability" + Both setups showed some variability at smaller message sizes—typical behavior due to kernel launch latencies—but converged to stable, identical peak bandwidths for larger transfers. The fluctuations at smaller sizes are likely caused by non-deterministic factors such as CPU-induced pauses during GPU kernel launches, occasionally favoring one setup over the other. + +#### RDMA write + +To isolate the network from any framework overhead, we ran direct device-to-device RDMA write tests. 
This measures the raw data transfer speed between GPUs in different nodes. + + + +The results were definitive: bidirectional bandwidth was virtually identical in both bare-metal and container environments across all message sizes, from a tiny 2 bytes up to 8MB. + +#### Conclusion + +Our experiments consistently demonstrate that running multi-node AI workloads inside containers does not degrade interconnect performance. The performance of RCCL collectives and raw RDMA bandwidth on AMD GPUs is on par with a bare-metal configuration. This debunks the myth of a "container tax" and validates containers as a first-class choice for scalable AI infrastructure. + +## Benchmark 2: Partition performance isolated vs mesh + +The AMD GPU can be [partitioned :material-arrow-top-right-thin:{ .external }](https://instinct.docs.amd.com/projects/amdgpu-docs/en/latest/gpu-partitioning/mi300x/overview.html){:target="_blank"} into smaller, independent units (e.g., NPS4 mode splits one GPU into four partitions). This promises better memory bandwidth utilization. Does this theoretical gain translate to better performance in practice? + +### Finding 1: Higher performance for isolated partitions + +First, we sought to reproduce and extend findings from the [official ROCm blog :material-arrow-top-right-thin:{ .external }](https://rocm.blogs.amd.com/software-tools-optimization/compute-memory-modes/README.html){:target="_blank"}. We benchmarked the memory bandwidth of a single partition (in CPX/NPS4 mode) against a full, unpartitioned GPU (in SPX/NPS1 mode). + + + +Our results confirmed that a single partition offers superior memory bandwidth. After aggregating the results to ensure an apples-to-apples comparison, we found the partitioned mode delivered consistently higher memory bandwidth across all message sizes, with especially large gains in the 32MB to 128MB range. 
+ +### Finding 2: Worse performance for partition meshes + +Our benchmark showed that isolated partitions in CPX/NPS4 mode deliver strong memory bandwidth. But can these partitions work efficiently together in mesh scenarios? If performance drops when partitions communicate or share load, the GPU loses significant value for real-world workloads. + +#### Data-parallel inference + +We ran eight independent vLLM instances on eight partitions of a single MI300X and compared their combined throughput against one vLLM instance on a single unpartitioned GPU. The single GPU was significantly faster, and the performance gap widened as the request rate increased. The partitions were starved for memory, limiting their ability to handle the KV cache for a high volume of requests. + + + +The degradation stems from increased memory pressure, as each partition has only a fraction of GPU memory, limiting its ability to handle larger workloads efficiently. + +#### Tensor-parallel inference + +We built a toy inference benchmark with PyTorch’s native distributed support to simulate Tensor Parallelism. A single GPU in SPX/NPS1 mode significantly outperformed the combined throughput of 8xCPX/NPS4 partitions. + + + +The gap stems from the overhead of collective operations like `all_gather`, which are needed to synchronize partial outputs across GPU partitions. + +#### Conclusion + +Although GPU partitioning provides a memory bandwidth boost in isolated microbenchmarks, this benefit does not carry over to practical inference scenarios. + +In reality, performance is limited by two factors: + +1. **Reduced memory**: Each partition has only a fraction of the GPU's total HBM, creating a bottleneck for memory-hungry tasks like storing KV caches. +2. **Communication overhead**: When partitions must work together, the cost of communication between them negates the performance gains. 
+ +GPU partitioning is only practical if used dynamically—for instance, to run multiple small development jobs or lightweight models, and then "unfractioning" the GPU back to its full power for larger, more demanding workloads. + +#### Limitations + +1. **Reproducibility**: AMD’s original blog post on partitioning lacked detailed setup information, so we had to reconstruct the benchmarks independently. +2. **Network tuning**: These benchmarks were run on a default, out-of-the-box network configuration. Our results for RCCL (~339 GB/s) and RDMA (~726 Gbps) are slightly below the peak figures [reported by Dell :material-arrow-top-right-thin:{ .external }](https://infohub.delltechnologies.com/en-us/l/generative-ai-in-the-enterprise-with-amd-accelerators/rccl-and-perftest-for-cluster-validation-1/4/){:target="_blank"}. This suggests that further performance could be unlocked with expert tuning of network topology, MTU size, and NCCL environment variables. + +## Benchmark setup + +### Hardware configuration + +Two nodes with below specifications: + +* Dell PowerEdge XE9680 (MI300X) +* CPU: 2 x Intel Xeon Platinum 8462Y+ +* RAM: 2.0 TiB +* GPU: 8 x AMD MI300X +* OS: Ubuntu 22.04.5 LTS +* ROCm: 6.4.1 +* AMD SMI: 25.4.2+aca1101 + +### Benchmark methodology + +The full, reproducible steps are available in our GitHub repository. Below is a summary of the approach. + +#### Creating a fleet + +We first defined a `dstack` [SSH fleet](../../docs/concepts/fleets.md#ssh) to manage the two-node cluster. + +```yaml +type: fleet +name: hotaisle-fleet +placement: any +ssh_config: + user: hotaisle + identity_file: ~/.ssh/id_rsa + hosts: + - hostname: ssh.hotaisle.cloud + port: 22007 + - hostname: ssh.hotaisle.cloud + port: 22015 +``` + +#### Bare-metal + +**RCCL tests** + +1. Install OpenMPI: + +```shell +apt install libopenmpi-dev openmpi-bin +``` + +2. Clone the RCCL tests repository + +```shell +git clone https://github.com/ROCm/rccl-tests.git +``` + +3. 
Build RCCL tests + +```shell +cd rccl-tests +make MPI=1 MPI_HOME=$OPEN_MPI_HOME +``` + +4. Create a hostfile with node IPs + +```shell +cat > hostfile < + +```shell +$ dstack apply -f .dstack.yml + +Active run my-service already exists. Detected changes that can be updated in-place: +- Repo state (branch, commit, or other) +- File archives +- Configuration properties: + - env + - files + +Update the run? [y/n]: y +⠋ Launching my-service... + + NAME BACKEND PRICE STATUS SUBMITTED + my-service deployment=1 running 11 mins ago + replica=0 deployment=0 aws (us-west-2) $0.0026 terminating 11 mins ago + replica=1 deployment=1 aws (us-west-2) $0.0026 running 1 min ago +``` + + + + + +#### Secrets + +Secrets let you centrally manage sensitive data like API keys and credentials. They’re scoped to a project, managed by project admins, and can be [securely referenced](../../docs/concepts/secrets.md) in run configurations. + +
+ +```yaml hl_lines="7" +type: task +name: train + +image: nvcr.io/nvidia/pytorch:25.05-py3 +registry_auth: + username: $oauthtoken + password: ${{ secrets.ngc_api_key }} + +commands: + - git clone https://github.com/pytorch/examples.git pytorch-examples + - cd pytorch-examples/distributed/ddp-tutorial-series + - pip install -r requirements.txt + - | + torchrun \ + --nproc-per-node=$DSTACK_GPUS_PER_NODE \ + --nnodes=$DSTACK_NODES_NUM \ + multinode.py 50 10 + +resources: + gpu: H100:1..2 + shm_size: 24GB +``` + +
+ +#### Files + +By default, `dstack` mounts the repo directory (where you ran `dstack init`) to all runs. + +If the directory is large or you need files outside of it, use the new [files](../../docs/concepts/dev-environments/#files) property to map specific local paths into the container. + +
```yaml
type: task
name: trl-sft

files:
  - .:examples # Maps the directory where `.dstack.yml` is located to `/workflow/examples`
  - ~/.ssh/id_rsa:/root/.ssh/id_rsa # Maps `~/.ssh/id_rsa` to `/root/.ssh/id_rsa`

python: 3.12

env:
  - HF_TOKEN
  - HF_HUB_ENABLE_HF_TRANSFER=1
  - MODEL=Qwen/Qwen2.5-0.5B
  - DATASET=stanfordnlp/imdb

commands:
  - uv pip install trl
  - |
    trl sft \
      --model_name_or_path $MODEL --dataset_name $DATASET \
      --num_processes $DSTACK_GPUS_PER_NODE

resources:
  gpu: H100:1
```
+ +#### Tenstorrent + +`dstack` remains committed to supporting multiple GPU vendors—including NVIDIA, AMD, TPUs, and more recently, [Tenstorrent :material-arrow-top-right-thin:{ .external }](https://tenstorrent.com/){:target="_blank"}. The latest release improves Tenstorrent support by handling hosts with multiple N300 cards and adds Docker-in-Docker support. + + + +Huge thanks to the Tenstorrent community for testing these improvements! + +#### Docker in Docker + +Using Docker inside `dstack` run configurations is now even simpler. Just set `docker` to `true` to [enable the use of Docker CLI](../../docs/concepts/tasks.md#docker-in-docker) in your runs—allowing you to build images, run containers, use Docker Compose, and more. + +
+ +```yaml +type: task +name: docker-nvidia-smi + +docker: true + +commands: + - | + docker run --gpus all \ + nvidia/cuda:12.3.0-base-ubuntu22.04 \ + nvidia-smi + +resources: + gpu: H100:1 +``` + +
#### AWS EFA

EFA is a network interface for EC2 that enables low-latency, high-bandwidth communication between nodes—crucial for scaling distributed deep learning. With `dstack`, EFA is automatically enabled when using supported instance types in fleets. Check out our [example](../../examples/clusters/efa/index.md).

#### Default Docker images

If no `image` is specified, `dstack` uses a base Docker image that now comes pre-configured with `uv`, `python`, `pip`, essential CUDA drivers, InfiniBand, and NCCL tests (located at `/opt/nccl-tests/build`).
+ +```yaml +type: task +name: nccl-tests + +nodes: 2 + +startup_order: workers-first +stop_criteria: master-done + +env: + - NCCL_DEBUG=INFO +commands: + - | + if [ $DSTACK_NODE_RANK -eq 0 ]; then + mpirun \ + --allow-run-as-root \ + --hostfile $DSTACK_MPI_HOSTFILE \ + -n $DSTACK_GPUS_NUM \ + -N $DSTACK_GPUS_PER_NODE \ + --bind-to none \ + /opt/nccl-tests/build/all_reduce_perf -b 8 -e 8G -f 2 -g 1 + else + sleep infinity + fi + +resources: + gpu: nvidia:1..8 + shm_size: 16GB +``` + +
+ +These images are optimized for common use cases and kept lightweight—ideal for everyday development, training, and inference. + +#### Server performance + +Server-side performance has been improved. With optimized handling and background processing, each server replica can now handle more runs. + +#### Google SSO + +Alongside the open-source version, `dstack` also offers [dstack Enterprise :material-arrow-top-right-thin:{ .external }](https://github.com/dstackai/dstack-enterprise){:target="_blank"} — which adds dedicated support and extra integrations like Single Sign-On (SSO). The latest release introduces support for configuring your company’s Google account for authentication. + + + +If you’d like to learn more about `dstack` Enterprise, [let us know](https://calendly.com/dstackai/discovery-call). + +That’s all for now. + +!!! info "What's next?" + Give dstack a try, and share your feedback—whether it’s [GitHub :material-arrow-top-right-thin:{ .external }](https://github.com/dstackai/dstack){:target="_blank"} issues, PRs, or questions on [Discord :material-arrow-top-right-thin:{ .external }](https://discord.gg/u8SmfwPpMd){:target="_blank"}. We’re eager to hear from you! 
diff --git a/docs/blog/posts/cursor.md b/docs/blog/posts/cursor.md index bdc1e4a61..a5f960469 100644 --- a/docs/blog/posts/cursor.md +++ b/docs/blog/posts/cursor.md @@ -5,7 +5,7 @@ description: "TBA" slug: cursor image: https://dstack.ai/static-assets/static-assets/images/dstack-cursor-v2.png categories: - - Releases + - Changelog --- # Accessing dev environments with Cursor diff --git a/docs/blog/posts/dstack-metrics.md b/docs/blog/posts/dstack-metrics.md index f06ff3151..f4647d782 100644 --- a/docs/blog/posts/dstack-metrics.md +++ b/docs/blog/posts/dstack-metrics.md @@ -5,7 +5,7 @@ description: "dstack introduces a new CLI command (and API) for monitoring conta slug: dstack-metrics image: https://dstack.ai/static-assets/static-assets/images/dstack-stats-v2.png categories: - - Releases + - Changelog --- # Monitoring essential GPU metrics via CLI diff --git a/docs/blog/posts/dstack-sky-own-cloud-accounts.md b/docs/blog/posts/dstack-sky-own-cloud-accounts.md index ff0b8d182..16b68867c 100644 --- a/docs/blog/posts/dstack-sky-own-cloud-accounts.md +++ b/docs/blog/posts/dstack-sky-own-cloud-accounts.md @@ -4,7 +4,7 @@ date: 2024-06-11 description: "With today's release, dstack Sky supports both options: accessing the GPU marketplace and using your own cloud accounts." slug: dstack-sky-own-cloud-accounts categories: - - Releases + - Changelog --- # dstack Sky now supports your own cloud accounts diff --git a/docs/blog/posts/dstack-sky.md b/docs/blog/posts/dstack-sky.md index 7cfe80097..78d35641c 100644 --- a/docs/blog/posts/dstack-sky.md +++ b/docs/blog/posts/dstack-sky.md @@ -3,7 +3,7 @@ date: 2024-03-11 description: A managed service that enables you to get GPUs at competitive rates from a wide pool of providers. 
slug: dstack-sky categories: - - Releases + - Changelog --- # Introducing dstack Sky diff --git a/docs/blog/posts/gh200-on-lambda.md b/docs/blog/posts/gh200-on-lambda.md index 970c87b12..1741e6f2e 100644 --- a/docs/blog/posts/gh200-on-lambda.md +++ b/docs/blog/posts/gh200-on-lambda.md @@ -5,7 +5,7 @@ description: "TBA" slug: gh200-on-lambda image: https://dstack.ai/static-assets/static-assets/images/dstack-arm--gh200-lambda-min.png categories: - - Releases + - Changelog --- # Supporting ARM and NVIDIA GH200 on Lambda diff --git a/docs/blog/posts/gpu-blocks-and-proxy-jump.md b/docs/blog/posts/gpu-blocks-and-proxy-jump.md index dc8bea1df..cbf9ab7dc 100644 --- a/docs/blog/posts/gpu-blocks-and-proxy-jump.md +++ b/docs/blog/posts/gpu-blocks-and-proxy-jump.md @@ -5,7 +5,7 @@ description: "TBA" slug: gpu-blocks-and-proxy-jump image: https://dstack.ai/static-assets/static-assets/images/data-centers-and-private-clouds.png categories: - - Releases + - Changelog --- # Introducing GPU blocks and proxy jump for SSH fleets diff --git a/docs/blog/posts/inactivity-duration.md b/docs/blog/posts/inactivity-duration.md index 28ea3f5d7..d04a8eba4 100644 --- a/docs/blog/posts/inactivity-duration.md +++ b/docs/blog/posts/inactivity-duration.md @@ -5,7 +5,7 @@ description: "dstack introduces a new feature that automatically detects and shu slug: inactivity-duration image: https://dstack.ai/static-assets/static-assets/images/inactive-dev-environments-auto-shutdown.png categories: - - Releases + - Changelog --- # Auto-shutdown for inactive dev environments—no idle GPUs diff --git a/docs/blog/posts/instance-volumes.md b/docs/blog/posts/instance-volumes.md index c4f5e3b1b..95b48cee1 100644 --- a/docs/blog/posts/instance-volumes.md +++ b/docs/blog/posts/instance-volumes.md @@ -5,7 +5,7 @@ description: "To simplify caching across runs and the use of NFS, we introduce a image: https://dstack.ai/static-assets/static-assets/images/dstack-instance-volumes.png slug: instance-volumes categories: - - 
Releases + - Changelog --- # Introducing instance volumes to persist data on instances diff --git a/docs/blog/posts/intel-gaudi.md b/docs/blog/posts/intel-gaudi.md index 7abc69f41..6f95f49d0 100644 --- a/docs/blog/posts/intel-gaudi.md +++ b/docs/blog/posts/intel-gaudi.md @@ -5,7 +5,7 @@ description: "dstack now supports Intel Gaudi accelerators with SSH fleets, simp slug: intel-gaudi image: https://dstack.ai/static-assets/static-assets/images/dstack-intel-gaudi-and-intel-tiber-cloud.png-v2 categories: - - Releases + - Changelog --- # Supporting Intel Gaudi AI accelerators with SSH fleets diff --git a/docs/blog/posts/metrics-ui.md b/docs/blog/posts/metrics-ui.md index 74719af2d..b15bbffc5 100644 --- a/docs/blog/posts/metrics-ui.md +++ b/docs/blog/posts/metrics-ui.md @@ -5,7 +5,7 @@ description: "TBA" slug: metrics-ui image: https://dstack.ai/static-assets/static-assets/images/dstack-metrics-ui-v3-min.png categories: - - Releases + - Changelog --- # Built-in UI for monitoring essential GPU metrics diff --git a/docs/blog/posts/mpi.md b/docs/blog/posts/mpi.md index ef5d68582..5473d64a2 100644 --- a/docs/blog/posts/mpi.md +++ b/docs/blog/posts/mpi.md @@ -5,7 +5,7 @@ description: "TBA" slug: mpi image: https://dstack.ai/static-assets/static-assets/images/dstack-mpi-v2.png categories: - - Releases + - Changelog --- # Supporting MPI and NCCL/RCCL tests diff --git a/docs/blog/posts/nebius.md b/docs/blog/posts/nebius.md index c6f4374db..ef484b0f3 100644 --- a/docs/blog/posts/nebius.md +++ b/docs/blog/posts/nebius.md @@ -5,7 +5,7 @@ description: "TBA" slug: nebius image: https://dstack.ai/static-assets/static-assets/images/dstack-nebius-v2.png categories: - - Releases + - Changelog --- # Supporting GPU provisioning and orchestration on Nebius diff --git a/docs/blog/posts/nvidia-and-amd-on-vultr.md b/docs/blog/posts/nvidia-and-amd-on-vultr.md index eda4519b4..2fb30ebbc 100644 --- a/docs/blog/posts/nvidia-and-amd-on-vultr.md +++ b/docs/blog/posts/nvidia-and-amd-on-vultr.md @@ 
-5,7 +5,7 @@ description: "Introducing integration with Vultr: The new integration allows Vul slug: nvidia-and-amd-on-vultr image: https://dstack.ai/static-assets/static-assets/images/dstack-vultr.png categories: - - Releases + - Changelog --- # Supporting NVIDIA and AMD accelerators on Vultr diff --git a/docs/blog/posts/prometheus.md b/docs/blog/posts/prometheus.md index 58299dfb4..5482d0c13 100644 --- a/docs/blog/posts/prometheus.md +++ b/docs/blog/posts/prometheus.md @@ -5,7 +5,7 @@ description: "TBA" slug: prometheus image: https://dstack.ai/static-assets/static-assets/images/dstack-prometheus-v3.png categories: - - Releases + - Changelog --- # Exporting GPU, cost, and other metrics to Prometheus diff --git a/docs/blog/posts/tpu-on-gcp.md b/docs/blog/posts/tpu-on-gcp.md index c38bea20a..8fff83cb4 100644 --- a/docs/blog/posts/tpu-on-gcp.md +++ b/docs/blog/posts/tpu-on-gcp.md @@ -4,7 +4,7 @@ date: 2024-09-10 description: "Learn how to use TPUs with dstack for fine-tuning and deploying LLMs, leveraging open-source tools like Hugging Face’s Optimum TPU and vLLM." slug: tpu-on-gcp categories: - - Releases + - Changelog --- # Using TPUs for fine-tuning and deploying LLMs diff --git a/docs/blog/posts/volumes-on-runpod.md b/docs/blog/posts/volumes-on-runpod.md index a6d436790..de0c8d6d0 100644 --- a/docs/blog/posts/volumes-on-runpod.md +++ b/docs/blog/posts/volumes-on-runpod.md @@ -4,7 +4,7 @@ date: 2024-08-13 description: "Learn how to use volumes with dstack to optimize model inference cold start times on RunPod." slug: volumes-on-runpod categories: - - Releases + - Changelog --- # Using volumes to optimize cold starts on RunPod diff --git a/docs/docs/concepts/services.md b/docs/docs/concepts/services.md index 5dd92f19c..a93cacf0d 100644 --- a/docs/docs/concepts/services.md +++ b/docs/docs/concepts/services.md @@ -679,6 +679,61 @@ utilization_policy: [`max_price`](../reference/dstack.yml/service.md#max_price), and among [others](../reference/dstack.yml/service.md). 
+## Rolling deployment + +To deploy a new version of a service that is already `running`, use `dstack apply`. `dstack` will automatically detect changes and suggest a rolling deployment update. + +
+ +```shell +$ dstack apply -f my-service.dstack.yml + +Active run my-service already exists. Detected changes that can be updated in-place: +- Repo state (branch, commit, or other) +- File archives +- Configuration properties: + - env + - files + +Update the run? [y/n]: +``` + +
+ +If approved, `dstack` gradually updates the service replicas. To update a replica, `dstack` starts a new replica, waits for it to become `running`, then terminates the old replica. This process is repeated for each replica, one at a time. + +You can track the progress of rolling deployment in both `dstack apply` or `dstack ps`. +Older replicas have lower `deployment` numbers; newer ones have higher. + + + +```shell +$ dstack apply -f my-service.dstack.yml + +⠋ Launching my-service... + NAME BACKEND PRICE STATUS SUBMITTED + my-service deployment=1 running 11 mins ago + replica=0 job=0 deployment=0 aws (us-west-2) $0.0026 terminating 11 mins ago + replica=1 job=0 deployment=1 aws (us-west-2) $0.0026 running 1 min ago +``` + +The rolling deployment stops when all replicas are updated or when a new deployment is submitted. + +??? info "Supported properties" + + + Rolling deployment supports changes to the following properties: `port`, `resources`, `volumes`, `docker`, `files`, `image`, `user`, `privileged`, `entrypoint`, `working_dir`, `python`, `nvcc`, `single_branch`, `env`, `shell`, `commands`, as well as changes to [repo](repos.md) or [file](#files) contents. + + Changes to `replicas` and `scaling` can be applied without redeploying replicas. + + Changes to other properties require a full service restart. + + To trigger a rolling deployment when no properties have changed (e.g., after updating [secrets](secrets.md) or to restart all replicas), + make a minor config change, such as adding a dummy [environment variable](#environment-variables). + --8<-- "docs/concepts/snippets/manage-runs.ext" !!! info "What's next?" diff --git a/docs/docs/guides/server-deployment.md b/docs/docs/guides/server-deployment.md index 8ff4034e6..ee47f481a 100644 --- a/docs/docs/guides/server-deployment.md +++ b/docs/docs/guides/server-deployment.md @@ -124,7 +124,11 @@ Postgres has no such limitation and is recommended for production deployment. 
### PostgreSQL -To store the server state in Postgres, set the `DSTACK_DATABASE_URL` environment variable. +To store the server state in Postgres, set the `DSTACK_DATABASE_URL` environment variable: + +```shell +$ DSTACK_DATABASE_URL=postgresql+asyncpg://user:password@db-host:5432/dstack dstack server +``` ??? info "Migrate from SQLite to PostgreSQL" You can migrate the existing state from SQLite to PostgreSQL using `pgloader`: diff --git a/docs/docs/reference/environment-variables.md b/docs/docs/reference/environment-variables.md index 4c5d44bd5..3c28ba333 100644 --- a/docs/docs/reference/environment-variables.md +++ b/docs/docs/reference/environment-variables.md @@ -123,6 +123,7 @@ For more details on the options below, refer to the [server deployment](../guide - `DSTACK_DB_POOL_SIZE`{ #DSTACK_DB_POOL_SIZE } - The client DB connections pool size. Defaults to `20`, - `DSTACK_DB_MAX_OVERFLOW`{ #DSTACK_DB_MAX_OVERFLOW } - The client DB connections pool allowed overflow. Defaults to `20`. - `DSTACK_SERVER_BACKGROUND_PROCESSING_FACTOR`{ #DSTACK_SERVER_BACKGROUND_PROCESSING_FACTOR } - The number of background jobs for processing server resources. Increase if you need to process more resources per server replica quickly. Defaults to `1`. +- `DSTACK_SERVER_BACKGROUND_PROCESSING_DISABLED`{ #DSTACK_SERVER_BACKGROUND_PROCESSING_DISABLED } - Disables background processing if set to any value. Useful to run only web frontend and API server. ??? info "Internal environment variables" The following environment variables are intended for development purposes: diff --git a/docs/overrides/header-2.html b/docs/overrides/header-2.html index 2c7e67679..4f8542d38 100644 --- a/docs/overrides/header-2.html +++ b/docs/overrides/header-2.html @@ -62,7 +62,7 @@
GitHub - Discord +
{% if "navigation.tabs.sticky" in features %} diff --git a/docs/overrides/home.html b/docs/overrides/home.html index 8c306dd36..b7c3ee994 100644 --- a/docs/overrides/home.html +++ b/docs/overrides/home.html @@ -509,20 +509,20 @@

FAQ

- +


@@ -539,8 +539,8 @@

dstack Enterprise

- Book a demo + class="md-button md-button--primary external small"> + Request a trial
diff --git a/docs/overrides/main.html b/docs/overrides/main.html index 4a74fb3a8..8db019610 100644 --- a/docs/overrides/main.html +++ b/docs/overrides/main.html @@ -102,38 +102,42 @@ + + Getting started + Concepts + Guides + Reference + + + + + + + {% endblock %} diff --git a/frontend/src/pages/Runs/Details/Logs/index.tsx b/frontend/src/pages/Runs/Details/Logs/index.tsx index aaffbf41c..6fc550103 100644 --- a/frontend/src/pages/Runs/Details/Logs/index.tsx +++ b/frontend/src/pages/Runs/Details/Logs/index.tsx @@ -31,7 +31,7 @@ export const Logs: React.FC = ({ className, projectName, runName, jobSub const writeDataToTerminal = (logs: ILogItem[]) => { logs.forEach((logItem) => { - terminalInstance.current.write(logItem.message); + terminalInstance.current.write(logItem.message.replace(/(? { const { data, isLoading, refreshList, isLoadingMore } = useInfiniteScroll({ useLazyQuery: useLazyGetRunsQuery, - args: { ...filteringRequestParams, limit: DEFAULT_TABLE_PAGE_SIZE }, + args: { ...filteringRequestParams, limit: DEFAULT_TABLE_PAGE_SIZE, job_submissions_limit: 1 }, getPaginationParams: (lastRun) => ({ prev_submitted_at: lastRun.submitted_at }), }); diff --git a/frontend/src/services/project.ts b/frontend/src/services/project.ts index 9e05e444a..c7559784e 100644 --- a/frontend/src/services/project.ts +++ b/frontend/src/services/project.ts @@ -4,6 +4,8 @@ import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react'; import { base64ToArrayBuffer } from 'libs'; import fetchBaseQueryHeaders from 'libs/fetchBaseQueryHeaders'; +const decoder = new TextDecoder('utf-8'); + // Helper function to transform backend response to frontend format // eslint-disable-next-line @typescript-eslint/no-explicit-any const transformProjectResponse = (project: any): IProject => ({ @@ -131,7 +133,7 @@ export const projectApi = createApi({ transformResponse: (response: { logs: ILogItem[]; next_token: string }) => { const logs = response.logs.map((logItem) => ({ ...logItem, - message: 
base64ToArrayBuffer(logItem.message as string), + message: decoder.decode(base64ToArrayBuffer(logItem.message)), })); return { diff --git a/frontend/src/types/log.d.ts b/frontend/src/types/log.d.ts index eec182c1e..99e9532c8 100644 --- a/frontend/src/types/log.d.ts +++ b/frontend/src/types/log.d.ts @@ -1,7 +1,7 @@ declare interface ILogItem { log_source: 'stdout' | 'stderr'; timestamp: string; - message: string | Uint8Array; + message: string; } declare type TRequestLogsParams = { diff --git a/frontend/src/types/run.d.ts b/frontend/src/types/run.d.ts index eae9ebacc..2e613defb 100644 --- a/frontend/src/types/run.d.ts +++ b/frontend/src/types/run.d.ts @@ -7,6 +7,7 @@ declare type TRunsRequestParams = { prev_run_id?: string; limit?: number; ascending?: boolean; + job_submissions_limit?: number; }; declare type TDeleteRunsRequestParams = { diff --git a/mkdocs.yml b/mkdocs.yml index e6fac58ed..7d93d9623 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -172,7 +172,7 @@ markdown_extensions: - pymdownx.tasklist: custom_checkbox: true - toc: - toc_depth: 5 + toc_depth: 3 permalink: true - attr_list - md_in_html @@ -292,8 +292,9 @@ nav: - TPU: examples/accelerators/tpu/index.md - Intel Gaudi: examples/accelerators/intel/index.md - Tenstorrent: examples/accelerators/tenstorrent/index.md - - Benchmarks: blog/benchmarks.md + - Changelog: blog/changelog.md - Case studies: blog/case-studies.md + - Benchmarks: blog/benchmarks.md - Blog: - blog/index.md # - Discord: https://discord.gg/u8SmfwPpMd" target="_blank diff --git a/pyproject.toml b/pyproject.toml index 47353886b..736ba6768 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "gpuhunt==0.1.6", "argcomplete>=3.5.0", "ignore-python>=0.2.0", + "orjson", ] [project.urls] diff --git a/runner/go.mod b/runner/go.mod index 22dad6466..850ea8253 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -1,6 +1,6 @@ module github.com/dstackai/dstack/runner -go 1.23 +go 1.23.8 require ( 
github.com/alexellis/go-execute/v2 v2.2.1 @@ -10,6 +10,7 @@ require ( github.com/docker/docker v26.0.0+incompatible github.com/docker/go-connections v0.5.0 github.com/docker/go-units v0.5.0 + github.com/dstackai/ansistrip v0.0.6 github.com/go-git/go-git/v5 v5.12.0 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f github.com/gorilla/websocket v1.5.1 @@ -62,6 +63,7 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect + github.com/tidwall/btree v1.7.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/ulikunitz/xz v0.5.12 // indirect @@ -77,10 +79,11 @@ require ( golang.org/x/mod v0.17.0 // indirect golang.org/x/net v0.24.0 // indirect golang.org/x/sync v0.7.0 // indirect + golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.20.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gotest.tools/v3 v3.5.0 // indirect + gotest.tools/v3 v3.5.1 // indirect ) diff --git a/runner/go.sum b/runner/go.sum index 41e133c46..1222fcac8 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -47,6 +47,8 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dstackai/ansistrip v0.0.6 h1:6qqeDNWt8NoqfkY1CxKUvdHpJzBl89LOE3wMwptVpaI= +github.com/dstackai/ansistrip v0.0.6/go.mod h1:w3ejXI0twxDv6bPXhkOaPeYdbwz2nwcrcvFoZGqi9F0= github.com/ebitengine/purego v0.8.1 
h1:sdRKd6plj7KYW33EH5As6YKfe8m9zbN9JMrOjNVF/BE= github.com/ebitengine/purego v0.8.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU= @@ -171,6 +173,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= @@ -281,8 +285,9 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.0.0-20170424234030-8be79e1e0910 h1:bCMaBn7ph495H+x72gEvgcv+mDRd9dElbzo/mVCMxX4= golang.org/x/time v0.0.0-20170424234030-8be79e1e0910/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools 
v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -318,5 +323,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= -gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= diff --git a/runner/internal/executor/base.go b/runner/internal/executor/base.go index 2163ca920..554bd7646 100644 --- a/runner/internal/executor/base.go +++ b/runner/internal/executor/base.go @@ -10,7 +10,7 @@ import ( type Executor interface { GetHistory(timestamp int64) *schemas.PullResponse - GetJobLogsHistory() []schemas.LogEvent + GetJobWsLogsHistory() []schemas.LogEvent GetRunnerState() string Run(ctx context.Context) error SetCodePath(codePath string) diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index e14ce540d..d2ced5dc3 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -11,6 +11,7 @@ import ( "os/exec" osuser "os/user" "path/filepath" + "runtime" "strconv" "strings" "sync" @@ -18,6 +19,7 @@ import ( "time" "github.com/creack/pty" + "github.com/dstackai/ansistrip" "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/connections" "github.com/dstackai/dstack/runner/internal/gerrors" @@ -27,6 +29,24 @@ import ( "github.com/prometheus/procfs" ) +// TODO: Tune these parameters for optimal experience/performance +const ( + // Output is flushed when the cursor doesn't move for this duration + 
AnsiStripFlushInterval = 500 * time.Millisecond + + // Output is flushed regardless of cursor activity after this maximum delay + AnsiStripMaxDelay = 3 * time.Second + + // Maximum buffer size for ansistrip + MaxBufferSize = 32 * 1024 // 32KB +) + +type ConnectionTracker interface { + GetNoConnectionsSecs() int64 + Track(ticker <-chan time.Time) + Stop() +} + type RunExecutor struct { tempDir string homeDir string @@ -47,13 +67,21 @@ type RunExecutor struct { state string jobStateHistory []schemas.JobStateEvent jobLogs *appendWriter + jobWsLogs *appendWriter runnerLogs *appendWriter timestamp *MonotonicTimestamp killDelay time.Duration - connectionTracker *connections.ConnectionTracker + connectionTracker ConnectionTracker } +// stubConnectionTracker is a no-op implementation for when procfs is not available (only required for tests on darwin) +type stubConnectionTracker struct{} + +func (s *stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 } +func (s *stubConnectionTracker) Track(ticker <-chan time.Time) {} +func (s *stubConnectionTracker) Stop() {} + func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort int) (*RunExecutor, error) { mu := &sync.RWMutex{} timestamp := NewMonotonicTimestamp() @@ -65,15 +93,25 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort i if err != nil { return nil, fmt.Errorf("failed to parse current user uid: %w", err) } - proc, err := procfs.NewDefaultFS() - if err != nil { - return nil, fmt.Errorf("failed to initialize procfs: %w", err) + + // Try to initialize procfs, but don't fail if it's not available (e.g., on macOS) + var connectionTracker ConnectionTracker + + if runtime.GOOS == "linux" { + proc, err := procfs.NewDefaultFS() + if err != nil { + return nil, fmt.Errorf("failed to initialize procfs: %w", err) + } + connectionTracker = connections.NewConnectionTracker(connections.ConnectionTrackerConfig{ + Port: uint64(sshPort), + MinConnDuration: 10 * time.Second, // 
shorter connections are likely from dstack-server + Procfs: proc, + }) + } else { + // Use stub connection tracker (only required for tests on darwin) + connectionTracker = &stubConnectionTracker{} } - connectionTracker := connections.NewConnectionTracker(connections.ConnectionTrackerConfig{ - Port: uint64(sshPort), - MinConnDuration: 10 * time.Second, // shorter connections are likely from dstack-server - Procfs: proc, - }) + return &RunExecutor{ tempDir: tempDir, homeDir: homeDir, @@ -86,6 +124,7 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort i state: WaitSubmit, jobStateHistory: make([]schemas.JobStateEvent, 0), jobLogs: newAppendWriter(mu, timestamp), + jobWsLogs: newAppendWriter(mu, timestamp), runnerLogs: newAppendWriter(mu, timestamp), timestamp: timestamp, @@ -129,7 +168,9 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { } }() - logger := io.MultiWriter(runnerLogFile, os.Stdout, ex.runnerLogs) + stripper := ansistrip.NewWriter(ex.runnerLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize) + defer stripper.Close() + logger := io.MultiWriter(runnerLogFile, os.Stdout, stripper) ctx = log.WithLogger(ctx, log.NewEntry(logger, int(log.DefaultEntry.Logger.Level))) // todo loglevel log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String()) @@ -431,7 +472,9 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error defer func() { _ = ptm.Close() }() defer func() { _ = cmd.Wait() }() // release resources if copy fails - logger := io.MultiWriter(jobLogFile, ex.jobLogs) + stripper := ansistrip.NewWriter(ex.jobLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize) + defer stripper.Close() + logger := io.MultiWriter(jobLogFile, ex.jobWsLogs, stripper) _, err = io.Copy(logger, ptm) if err != nil && !isPtyError(err) { return gerrors.Wrap(err) diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index 
8d275b137..e13184513 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "testing" "time" @@ -17,8 +18,6 @@ import ( "github.com/stretchr/testify/require" ) -// todo test get history - func TestExecutor_WorkingDir_Current(t *testing.T) { var b bytes.Buffer ex := makeTestExecutor(t) @@ -28,7 +27,8 @@ func TestExecutor_WorkingDir_Current(t *testing.T) { err := ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - assert.Equal(t, ex.workingDir+"\r\n", b.String()) + // Normalize line endings for cross-platform compatibility. + assert.Equal(t, ex.workingDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } func TestExecutor_WorkingDir_Nil(t *testing.T) { @@ -39,7 +39,7 @@ func TestExecutor_WorkingDir_Nil(t *testing.T) { err := ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - assert.Equal(t, ex.workingDir+"\r\n", b.String()) + assert.Equal(t, ex.workingDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } func TestExecutor_HomeDir(t *testing.T) { @@ -49,7 +49,7 @@ func TestExecutor_HomeDir(t *testing.T) { err := ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - assert.Equal(t, ex.homeDir+"\r\n", b.String()) + assert.Equal(t, ex.homeDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } func TestExecutor_NonZeroExit(t *testing.T) { @@ -61,7 +61,7 @@ func TestExecutor_NonZeroExit(t *testing.T) { assert.Error(t, err) assert.NotEmpty(t, ex.jobStateHistory) exitStatus := ex.jobStateHistory[len(ex.jobStateHistory)-1].ExitStatus - assert.NotNil(t, exitStatus, ex.jobStateHistory) + assert.NotNil(t, exitStatus) assert.Equal(t, 100, *exitStatus) } @@ -96,7 +96,7 @@ func TestExecutor_LocalRepo(t *testing.T) { err = ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - assert.Equal(t, "bar\r\n", b.String()) + assert.Equal(t, "bar\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } func 
TestExecutor_Recover(t *testing.T) { @@ -148,8 +148,8 @@ func TestExecutor_RemoteRepo(t *testing.T) { err = ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - expected := fmt.Sprintf("%s\r\n%s\r\n%s\r\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail) - assert.Equal(t, expected, b.String()) + expected := fmt.Sprintf("%s\n%s\n%s\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail) + assert.Equal(t, expected, strings.ReplaceAll(b.String(), "\r\n", "\n")) } /* Helpers */ @@ -236,3 +236,98 @@ func TestWriteDstackProfile(t *testing.T) { assert.Equal(t, value, string(out)) } } + +func TestExecutor_Logs(t *testing.T) { + var b bytes.Buffer + ex := makeTestExecutor(t) + // Use printf to generate ANSI control codes. + // \033[31m = red text, \033[1;32m = bold green text, \033[0m = reset + ex.jobSpec.Commands = append(ex.jobSpec.Commands, "printf '\\033[31mRed Hello World\\033[0m\\n' && printf '\\033[1;32mBold Green Line 2\\033[0m\\n' && printf 'Line 3\\n'") + + err := ex.execJob(context.TODO(), io.Writer(&b)) + assert.NoError(t, err) + + logHistory := ex.GetHistory(0).JobLogs + assert.NotEmpty(t, logHistory) + + logString := combineLogMessages(logHistory) + normalizedLogString := strings.ReplaceAll(logString, "\r\n", "\n") + + expectedOutput := "Red Hello World\nBold Green Line 2\nLine 3\n" + assert.Equal(t, expectedOutput, normalizedLogString, "Should strip ANSI codes from regular logs") + + // Verify timestamps are in order + assert.Greater(t, len(logHistory), 0) + for i := 1; i < len(logHistory); i++ { + assert.GreaterOrEqual(t, logHistory[i].Timestamp, logHistory[i-1].Timestamp) + } +} + +func TestExecutor_LogsWithErrors(t *testing.T) { + var b bytes.Buffer + ex := makeTestExecutor(t) + ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo 'Success message' && echo 'Error message' >&2 && exit 1") + + err := ex.execJob(context.TODO(), io.Writer(&b)) + 
assert.Error(t, err) + + logHistory := ex.GetHistory(0).JobLogs + assert.NotEmpty(t, logHistory) + + logString := combineLogMessages(logHistory) + normalizedLogString := strings.ReplaceAll(logString, "\r\n", "\n") + + expectedOutput := "Success message\nError message\n" + assert.Equal(t, expectedOutput, normalizedLogString) +} + +func TestExecutor_LogsAnsiCodeHandling(t *testing.T) { + var b bytes.Buffer + ex := makeTestExecutor(t) + + // Test a variety of ANSI escape sequences on stdout and stderr. + cmd := "printf '\\033[31mRed\\033[0m \\033[32mGreen\\033[0m\\n' && " + + "printf '\\033[1mBold\\033[0m \\033[4mUnderline\\033[0m\\n' && " + + "printf '\\033[s\\033[uPlain text\\n' >&2" + + ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd) + + err := ex.execJob(context.TODO(), io.Writer(&b)) + assert.NoError(t, err) + + // 1. Check WebSocket logs, which should preserve ANSI codes. + wsLogHistory := ex.GetJobWsLogsHistory() + assert.NotEmpty(t, wsLogHistory) + wsLogString := combineLogMessages(wsLogHistory) + normalizedWsLogString := strings.ReplaceAll(wsLogString, "\r\n", "\n") + + expectedWsOutput := "\033[31mRed\033[0m \033[32mGreen\033[0m\n" + + "\033[1mBold\033[0m \033[4mUnderline\033[0m\n" + + "\033[s\033[uPlain text\n" + assert.Equal(t, expectedWsOutput, normalizedWsLogString, "Websocket logs should preserve ANSI codes") + + // 2. Check regular job logs, which should have ANSI codes stripped. + regularLogHistory := ex.GetHistory(0).JobLogs + assert.NotEmpty(t, regularLogHistory) + regularLogString := combineLogMessages(regularLogHistory) + normalizedRegularLogString := strings.ReplaceAll(regularLogString, "\r\n", "\n") + + expectedRegularOutput := "Red Green\n" + + "Bold Underline\n" + + "Plain text\n" + assert.Equal(t, expectedRegularOutput, normalizedRegularLogString, "Regular logs should have ANSI codes stripped") + + // Verify timestamps are ordered for both log types. 
+ assert.Greater(t, len(wsLogHistory), 0) + for i := 1; i < len(wsLogHistory); i++ { + assert.GreaterOrEqual(t, wsLogHistory[i].Timestamp, wsLogHistory[i-1].Timestamp) + } +} + +func combineLogMessages(logHistory []schemas.LogEvent) string { + var logOutput bytes.Buffer + for _, logEvent := range logHistory { + logOutput.Write(logEvent.Message) + } + return logOutput.String() +} diff --git a/runner/internal/executor/query.go b/runner/internal/executor/query.go index 1dff4e330..6678e5f8d 100644 --- a/runner/internal/executor/query.go +++ b/runner/internal/executor/query.go @@ -4,8 +4,8 @@ import ( "github.com/dstackai/dstack/runner/internal/schemas" ) -func (ex *RunExecutor) GetJobLogsHistory() []schemas.LogEvent { - return ex.jobLogs.history +func (ex *RunExecutor) GetJobWsLogsHistory() []schemas.LogEvent { + return ex.jobWsLogs.history } func (ex *RunExecutor) GetHistory(timestamp int64) *schemas.PullResponse { diff --git a/runner/internal/metrics/metrics_test.go b/runner/internal/metrics/metrics_test.go index 7f280da25..d547e2e33 100644 --- a/runner/internal/metrics/metrics_test.go +++ b/runner/internal/metrics/metrics_test.go @@ -1,6 +1,7 @@ package metrics import ( + "runtime" "testing" "github.com/dstackai/dstack/runner/internal/schemas" @@ -8,6 +9,9 @@ import ( ) func TestGetAMDGPUMetrics_OK(t *testing.T) { + if runtime.GOOS == "darwin" { + t.Skip("Skipping on macOS") + } collector, err := NewMetricsCollector() assert.NoError(t, err) @@ -39,6 +43,9 @@ func TestGetAMDGPUMetrics_OK(t *testing.T) { } func TestGetAMDGPUMetrics_ErrorGPUUtilNA(t *testing.T) { + if runtime.GOOS == "darwin" { + t.Skip("Skipping on macOS") + } collector, err := NewMetricsCollector() assert.NoError(t, err) metrics, err := collector.getAMDGPUMetrics("gpu,gfx,gfx_clock,vram_used,vram_total\n0,N/A,N/A,283,196300\n") diff --git a/runner/internal/runner/api/ws.go b/runner/internal/runner/api/ws.go index cade1170a..ebb0caea2 100644 --- a/runner/internal/runner/api/ws.go +++ 
b/runner/internal/runner/api/ws.go @@ -34,23 +34,23 @@ func (s *Server) streamJobLogs(conn *websocket.Conn) { for { s.executor.RLock() - jobLogsHistory := s.executor.GetJobLogsHistory() + jobLogsWsHistory := s.executor.GetJobWsLogsHistory() select { case <-s.shutdownCh: - if currentPos >= len(jobLogsHistory) { + if currentPos >= len(jobLogsWsHistory) { s.executor.RUnlock() close(s.wsDoneCh) return } default: - if currentPos >= len(jobLogsHistory) { + if currentPos >= len(jobLogsWsHistory) { s.executor.RUnlock() time.Sleep(100 * time.Millisecond) continue } } - for currentPos < len(jobLogsHistory) { - if err := conn.WriteMessage(websocket.BinaryMessage, jobLogsHistory[currentPos].Message); err != nil { + for currentPos < len(jobLogsWsHistory) { + if err := conn.WriteMessage(websocket.BinaryMessage, jobLogsWsHistory[currentPos].Message); err != nil { s.executor.RUnlock() log.Error(context.TODO(), "Failed to write message", "err", err) return diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 19462c9d7..1834188ae 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -274,6 +274,13 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error { cfg := task.config var err error + runnerDir, err := d.dockerParams.MakeRunnerDir(task.containerName) + if err != nil { + return tracerr.Wrap(err) + } + task.runnerDir = runnerDir + log.Debug(ctx, "runner dir", "task", task.ID, "path", runnerDir) + if cfg.GPU != 0 { gpuIDs, err := d.gpuLock.Acquire(ctx, cfg.GPU) if err != nil { @@ -335,7 +342,10 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error { if err := d.tasks.Update(task); err != nil { return tracerr.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err) } - if err = pullImage(pullCtx, d.client, cfg); err != nil { + // Although it's called "runner dir", we also use it for shim task-related data. 
+ // Maybe we should rename it to "task dir" (including the `/root/.dstack/runners` dir on the host). + pullLogPath := filepath.Join(runnerDir, "pull.log") + if err = pullImage(pullCtx, d.client, cfg, pullLogPath); err != nil { errMessage := fmt.Sprintf("pullImage error: %s", err.Error()) log.Error(ctx, errMessage) task.SetStatusTerminated(string(types.TerminationReasonCreatingContainerError), errMessage) @@ -655,7 +665,7 @@ func mountDisk(ctx context.Context, deviceName, mountPoint string, fsRootPerms o return nil } -func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig) error { +func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig, logPath string) error { if !strings.Contains(taskConfig.ImageName, ":") { taskConfig.ImageName += ":latest" } @@ -685,51 +695,70 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf if err != nil { return tracerr.Wrap(err) } - defer func() { _ = reader.Close() }() + defer reader.Close() + + logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + if err != nil { + return tracerr.Wrap(err) + } + defer logFile.Close() + + teeReader := io.TeeReader(reader, logFile) current := make(map[string]uint) total := make(map[string]uint) - type ProgressDetail struct { - Current uint `json:"current"` - Total uint `json:"total"` - } - type Progress struct { - Id string `json:"id"` - Status string `json:"status"` - ProgressDetail ProgressDetail `json:"progressDetail"` //nolint:tagliatelle - Error string `json:"error"` + // dockerd reports pulling progress as a stream of JSON Lines. 
The format of records is not documented in the API documentation, + // although it's occasionally mentioned, e.g., https://docs.docker.com/reference/api/engine/version-history/#v148-api-changes + + // https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/pkg/jsonmessage/jsonmessage.go#L144 + // All fields are optional + type PullMessage struct { + Id string `json:"id"` // layer id + Status string `json:"status"` + ProgressDetail struct { + Current uint `json:"current"` // bytes + Total uint `json:"total"` // bytes + } `json:"progressDetail"` + ErrorDetail struct { + Message string `json:"message"` + } `json:"errorDetail"` } - var status bool + var pullCompleted bool pullErrors := make([]string, 0) - scanner := bufio.NewScanner(reader) + scanner := bufio.NewScanner(teeReader) for scanner.Scan() { line := scanner.Bytes() - var progressRow Progress - if err := json.Unmarshal(line, &progressRow); err != nil { + var pullMessage PullMessage + if err := json.Unmarshal(line, &pullMessage); err != nil { continue } - if progressRow.Status == "Downloading" { - current[progressRow.Id] = progressRow.ProgressDetail.Current - total[progressRow.Id] = progressRow.ProgressDetail.Total + if pullMessage.Status == "Downloading" { + current[pullMessage.Id] = pullMessage.ProgressDetail.Current + total[pullMessage.Id] = pullMessage.ProgressDetail.Total } - if progressRow.Status == "Download complete" { - current[progressRow.Id] = total[progressRow.Id] + if pullMessage.Status == "Download complete" { + current[pullMessage.Id] = total[pullMessage.Id] } - if progressRow.Error != "" { - log.Error(ctx, "error pulling image", "name", taskConfig.ImageName, "err", progressRow.Error) - pullErrors = append(pullErrors, progressRow.Error) + if pullMessage.ErrorDetail.Message != "" { + log.Error(ctx, "error pulling image", "name", taskConfig.ImageName, "err", pullMessage.ErrorDetail.Message) + pullErrors = append(pullErrors, pullMessage.ErrorDetail.Message) } - if 
strings.HasPrefix(progressRow.Status, "Status:") { - status = true - log.Debug(ctx, progressRow.Status) + // If the pull is successful, the last two entries must be: + // "Digest: sha256:" + // "Status: " + // where is either "Downloaded newer image for " or "Image is up to date for ". + // See: https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/daemon/containerd/image_pull.go#L134-L152 + // See: https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/daemon/containerd/image_pull.go#L257-L263 + if strings.HasPrefix(pullMessage.Status, "Status:") { + pullCompleted = true + log.Debug(ctx, pullMessage.Status) } } duration := time.Since(startTime) - var currentBytes uint var totalBytes uint for _, v := range current { @@ -738,9 +767,13 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf for _, v := range total { totalBytes += v } - speed := bytesize.New(float64(currentBytes) / duration.Seconds()) - if status && currentBytes == totalBytes { + + if err := ctx.Err(); err != nil { + return tracerr.Errorf("image pull interrupted: downloaded %d bytes out of %d (%s/s): %w", currentBytes, totalBytes, speed, err) + } + + if pullCompleted { log.Debug(ctx, "image successfully pulled", "bytes", currentBytes, "bps", speed) } else { return tracerr.Errorf( @@ -749,21 +782,11 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf ) } - err = ctx.Err() - if err != nil { - return tracerr.Errorf("imagepull interrupted: downloaded %d bytes out of %d (%s/s): %w", currentBytes, totalBytes, speed, err) - } return nil } func (d *DockerRunner) createContainer(ctx context.Context, task *Task) error { - runnerDir, err := d.dockerParams.MakeRunnerDir(task.containerName) - if err != nil { - return tracerr.Wrap(err) - } - task.runnerDir = runnerDir - - mounts, err := d.dockerParams.DockerMounts(runnerDir) + mounts, err := d.dockerParams.DockerMounts(task.runnerDir) if err != nil { return 
tracerr.Wrap(err) } diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 35c8cbab6..29a5e1afd 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -26,8 +26,9 @@ func TestDocker_SSHServer(t *testing.T) { t.Parallel() params := &dockerParametersMock{ - commands: []string{"echo 1"}, - sshPort: nextPort(), + commands: []string{"echo 1"}, + sshPort: nextPort(), + runnerDir: t.TempDir(), } timeout := 180 // seconds @@ -58,6 +59,7 @@ func TestDocker_SSHServerConnect(t *testing.T) { commands: []string{"sleep 5"}, sshPort: nextPort(), publicSSHKey: string(publicBytes), + runnerDir: t.TempDir(), } timeout := 180 // seconds @@ -103,7 +105,8 @@ func TestDocker_ShmNoexecByDefault(t *testing.T) { t.Parallel() params := &dockerParametersMock{ - commands: []string{"mount | grep '/dev/shm .*size=65536k' | grep noexec"}, + commands: []string{"mount | grep '/dev/shm .*size=65536k' | grep noexec"}, + runnerDir: t.TempDir(), } timeout := 180 // seconds @@ -125,7 +128,8 @@ func TestDocker_ShmExecIfSizeSpecified(t *testing.T) { t.Parallel() params := &dockerParametersMock{ - commands: []string{"mount | grep '/dev/shm .*size=1024k' | grep -v noexec"}, + commands: []string{"mount | grep '/dev/shm .*size=1024k' | grep -v noexec"}, + runnerDir: t.TempDir(), } timeout := 180 // seconds @@ -148,6 +152,7 @@ type dockerParametersMock struct { commands []string sshPort int publicSSHKey string + runnerDir string } func (c *dockerParametersMock) DockerPrivileged() bool { @@ -184,7 +189,7 @@ func (c *dockerParametersMock) DockerMounts(string) ([]mount.Mount, error) { } func (c *dockerParametersMock) MakeRunnerDir(string) (string, error) { - return "", nil + return c.runnerDir, nil } /* Utilities */ diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index b501f0b9d..2a7eeb4d5 100644 --- 
a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -25,6 +25,7 @@ ServerClientError, URLNotFoundError, ) +from dstack._internal.core.models.common import ApplyAction from dstack._internal.core.models.configurations import ApplyConfigurationType from dstack._internal.core.models.fleets import ( Fleet, @@ -72,7 +73,104 @@ def apply_configuration( spec=spec, ) _print_plan_header(plan) + if plan.action is not None: + self._apply_plan(plan, command_args) + else: + # Old servers don't support spec update + self._apply_plan_on_old_server(plan, command_args) + + def _apply_plan(self, plan: FleetPlan, command_args: argparse.Namespace): + delete_fleet_name: Optional[str] = None + action_message = "" + confirm_message = "" + if plan.current_resource is None: + if plan.spec.configuration.name is not None: + action_message += ( + f"Fleet [code]{plan.spec.configuration.name}[/] does not exist yet." + ) + confirm_message += "Create the fleet?" + else: + action_message += f"Found fleet [code]{plan.spec.configuration.name}[/]." + if plan.action == ApplyAction.CREATE: + delete_fleet_name = plan.current_resource.name + action_message += ( + " Configuration changes detected. Cannot update the fleet in-place" + ) + confirm_message += "Re-create the fleet?" + elif plan.current_resource.spec == plan.effective_spec: + if command_args.yes and not command_args.force: + # --force is required only with --yes, + # otherwise we may ask for force apply interactively. + console.print( + "No configuration changes detected. Use --force to apply anyway." + ) + return + delete_fleet_name = plan.current_resource.name + action_message += " No configuration changes detected." + confirm_message += "Re-create the fleet?" + else: + action_message += " Configuration changes detected." + confirm_message += "Update the fleet in-place?" 
+ + console.print(action_message) + if not command_args.yes and not confirm_ask(confirm_message): + console.print("\nExiting...") + return + + if delete_fleet_name is not None: + with console.status("Deleting existing fleet..."): + self.api.client.fleets.delete( + project_name=self.api.project, names=[delete_fleet_name] + ) + # Fleet deletion is async. Wait for fleet to be deleted. + while True: + try: + self.api.client.fleets.get( + project_name=self.api.project, name=delete_fleet_name + ) + except ResourceNotExistsError: + break + else: + time.sleep(1) + + try: + with console.status("Applying plan..."): + fleet = self.api.client.fleets.apply_plan(project_name=self.api.project, plan=plan) + except ServerClientError as e: + raise CLIError(e.msg) + if command_args.detach: + console.print("Fleet configuration submitted. Exiting...") + return + try: + with MultiItemStatus( + f"Provisioning [code]{fleet.name}[/]...", console=console + ) as live: + while not _finished_provisioning(fleet): + table = get_fleets_table([fleet]) + live.update(table) + time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS) + fleet = self.api.client.fleets.get(self.api.project, fleet.name) + except KeyboardInterrupt: + if confirm_ask("Delete the fleet before exiting?"): + with console.status("Deleting fleet..."): + self.api.client.fleets.delete( + project_name=self.api.project, names=[fleet.name] + ) + else: + console.print("Exiting... Fleet provisioning will continue in the background.") + return + console.print( + get_fleets_table( + [fleet], + verbose=_failed_provisioning(fleet), + format_date=local_time, + ) + ) + if _failed_provisioning(fleet): + console.print("\n[error]Some instances failed. 
Check the table above for errors.[/]") + exit(1) + def _apply_plan_on_old_server(self, plan: FleetPlan, command_args: argparse.Namespace): action_message = "" confirm_message = "" if plan.current_resource is None: @@ -86,7 +184,7 @@ def apply_configuration( diff = diff_models( old=plan.current_resource.spec.configuration, new=plan.spec.configuration, - ignore={ + reset={ "ssh_config": { "ssh_key": True, "proxy_jump": {"ssh_key"}, diff --git a/src/dstack/_internal/cli/services/profile.py b/src/dstack/_internal/cli/services/profile.py index 23bbe55ad..d57ea2e13 100644 --- a/src/dstack/_internal/cli/services/profile.py +++ b/src/dstack/_internal/cli/services/profile.py @@ -159,7 +159,7 @@ def apply_profile_args( if args.idle_duration is not None: profile_settings.idle_duration = args.idle_duration elif args.dont_destroy: - profile_settings.idle_duration = False + profile_settings.idle_duration = "off" if args.creation_policy_reuse: profile_settings.creation_policy = CreationPolicy.REUSE diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 1deea7ee0..5eaf10ed2 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -3,7 +3,16 @@ from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType from dstack._internal.core.models.configurations import ServiceConfiguration from dstack._internal.core.models.runs import ApplyRunPlanInput, JobSpec, JobSubmission, RunSpec -from dstack._internal.server.schemas.runs import GetRunPlanRequest +from dstack._internal.server.schemas.runs import GetRunPlanRequest, ListRunsRequest + + +def get_list_runs_excludes(list_runs_request: ListRunsRequest) -> IncludeExcludeSetType: + excludes = set() + if list_runs_request.include_jobs: + excludes.add("include_jobs") + if list_runs_request.job_submissions_limit is None: + excludes.add("job_submissions_limit") + return excludes def 
get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[IncludeExcludeDictType]: @@ -139,6 +148,8 @@ def get_job_spec_excludes(job_specs: list[JobSpec]) -> IncludeExcludeDictType: spec_excludes["repo_data"] = True if all(not s.file_archives for s in job_specs): spec_excludes["file_archives"] = True + if all(s.service_port is None for s in job_specs): + spec_excludes["service_port"] = True return spec_excludes diff --git a/src/dstack/_internal/core/compatibility/volumes.py b/src/dstack/_internal/core/compatibility/volumes.py index 7395674f9..4b7be6bb0 100644 --- a/src/dstack/_internal/core/compatibility/volumes.py +++ b/src/dstack/_internal/core/compatibility/volumes.py @@ -30,4 +30,6 @@ def _get_volume_configuration_excludes( configuration_excludes: IncludeExcludeDictType = {} if configuration.tags is None: configuration_excludes["tags"] = True + if configuration.auto_cleanup_duration is None: + configuration_excludes["auto_cleanup_duration"] = True return configuration_excludes diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index c347cf0d3..a13922671 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -1,11 +1,14 @@ import re from enum import Enum -from typing import Union +from typing import Any, Callable, Optional, Union +import orjson from pydantic import Field from pydantic_duality import DualBaseModel from typing_extensions import Annotated +from dstack._internal.utils.json_utils import pydantic_orjson_dumps + IncludeExcludeFieldType = Union[int, str] IncludeExcludeSetType = set[IncludeExcludeFieldType] IncludeExcludeDictType = dict[ @@ -20,7 +23,40 @@ # This allows to use the same model both for a strict parsing of the user input and # for a permissive parsing of the server responses. 
class CoreModel(DualBaseModel): - pass + class Config: + json_loads = orjson.loads + json_dumps = pydantic_orjson_dumps + + def json( + self, + *, + include: Optional[IncludeExcludeType] = None, + exclude: Optional[IncludeExcludeType] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, # ignore as it's deprecated + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + encoder: Optional[Callable[[Any], Any]] = None, + models_as_dict: bool = True, # does not seems to be needed by dstack or dependencies + **dumps_kwargs: Any, + ) -> str: + """ + Override `json()` method so that it calls `dict()`. + Allows changing how models are serialized by overriding `dict()` only. + By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place. + """ + data = self.dict( + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + if self.__custom_root_type__: + data = data["__root__"] + return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) class Duration(int): diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 97be403ca..a22db9e36 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -4,6 +4,7 @@ from pathlib import PurePosixPath from typing import Any, Dict, List, Optional, Union +import orjson from pydantic import Field, ValidationError, conint, constr, root_validator, validator from typing_extensions import Annotated, Literal @@ -18,6 +19,9 @@ from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point +from dstack._internal.utils.json_utils import ( + 
pydantic_orjson_dumps_with_indent, +) CommandsList = List[str] ValidPort = conint(gt=0, le=65536) @@ -394,8 +398,9 @@ class TaskConfiguration( class ServiceConfigurationParams(CoreModel): port: Annotated[ + # NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used. Union[ValidPort, constr(regex=r"^[0-9]+:[0-9]+$"), PortMapping], - Field(description="The port, that application listens on or the mapping"), + Field(description="The port the application listens on"), ] gateway: Annotated[ Optional[Union[bool, str]], @@ -573,6 +578,9 @@ class DstackConfiguration(CoreModel): ] class Config: + json_loads = orjson.loads + json_dumps = pydantic_orjson_dumps_with_indent + @staticmethod def schema_extra(schema: Dict[str, Any]): schema["$schema"] = "http://json-schema.org/draft-07/schema#" diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py index 6cf970a95..fd616b754 100644 --- a/src/dstack/_internal/core/models/fleets.py +++ b/src/dstack/_internal/core/models/fleets.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated, Literal from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.common import ApplyAction, CoreModel from dstack._internal.core.models.envs import Env from dstack._internal.core.models.instances import Instance, InstanceOfferWithAvailability, SSHKey from dstack._internal.core.models.profiles import ( @@ -324,6 +324,7 @@ class FleetPlan(CoreModel): offers: List[InstanceOfferWithAvailability] total_offers: int max_offer_price: Optional[float] = None + action: Optional[ApplyAction] = None # default value for backward compatibility def get_effective_spec(self) -> FleetSpec: if self.effective_spec is not None: diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 62997ce4e..78199608c 100644 --- 
a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -1,12 +1,14 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union, overload +import orjson from pydantic import Field, root_validator, validator from typing_extensions import Annotated, Literal from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel, Duration from dstack._internal.utils.common import list_enum_values_for_annotation +from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent from dstack._internal.utils.tags import tags_validator DEFAULT_RETRY_DURATION = 3600 @@ -74,11 +76,9 @@ def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str return parse_duration(v) -def parse_idle_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str, int, bool]]: - if v is False: +def parse_idle_duration(v: Optional[Union[int, str]]) -> Optional[Union[str, int]]: + if v == "off" or v == -1: return -1 - if v is True: - return None return parse_duration(v) @@ -249,7 +249,7 @@ class ProfileParams(CoreModel): ), ] = None idle_duration: Annotated[ - Optional[Union[Literal["off"], str, int, bool]], + Optional[Union[Literal["off"], str, int]], Field( description=( "Time to wait before terminating idle instances." 
@@ -343,6 +343,9 @@ class ProfilesConfig(CoreModel): profiles: List[Profile] class Config: + json_loads = orjson.loads + json_dumps = pydantic_orjson_dumps_with_indent + schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"} def default(self) -> Optional[Profile]: diff --git a/src/dstack/_internal/core/models/resources.py b/src/dstack/_internal/core/models/resources.py index a0ecff953..15c80f716 100644 --- a/src/dstack/_internal/core/models/resources.py +++ b/src/dstack/_internal/core/models/resources.py @@ -382,14 +382,6 @@ def schema_extra(schema: Dict[str, Any]): gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK - # TODO: Remove in 0.20. Added for backward compatibility. - @root_validator - def _post_validate(cls, values): - cpu = values.get("cpu") - if isinstance(cpu, CPUSpec) and cpu.arch in [None, gpuhunt.CPUArchitecture.X86]: - values["cpu"] = cpu.count - return values - def pretty_format(self) -> str: # TODO: Remove in 0.20. Use self.cpu directly cpu = parse_obj_as(CPUSpec, self.cpu) @@ -407,3 +399,18 @@ def pretty_format(self) -> str: resources.update(disk_size=self.disk.size) res = pretty_resources(**resources) return res + + def dict(self, *args, **kwargs) -> Dict: + # super() does not work with pydantic-duality + res = CoreModel.dict(self, *args, **kwargs) + self._update_serialized_cpu(res) + return res + + # TODO: Remove in 0.20. Added for backward compatibility. 
+ def _update_serialized_cpu(self, values: Dict): + cpu = values["cpu"] + if cpu: + arch = cpu.get("arch") + count = cpu.get("count") + if count and arch in [None, gpuhunt.CPUArchitecture.X86.value]: + values["cpu"] = count diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 49691eb50..3f67646ff 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -11,6 +11,7 @@ DEFAULT_REPO_DIR, AnyRunConfiguration, RunConfiguration, + ServiceConfiguration, ) from dstack._internal.core.models.files import FileArchiveMapping from dstack._internal.core.models.instances import ( @@ -101,6 +102,14 @@ def to_status(self) -> "RunStatus": } return mapping[self] + def to_error(self) -> Optional[str]: + if self == RunTerminationReason.RETRY_LIMIT_EXCEEDED: + return "retry limit exceeded" + elif self == RunTerminationReason.SERVER_ERROR: + return "server error" + else: + return None + class JobTerminationReason(str, Enum): # Set by the server @@ -162,6 +171,24 @@ def to_retry_event(self) -> Optional[RetryEvent]: default = RetryEvent.ERROR if self.to_status() == JobStatus.FAILED else None return mapping.get(self, default) + def to_error(self) -> Optional[str]: + # Should return None for values that are already + # handled and shown in status_message. 
+ error_mapping = { + JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable", + JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded", + JobTerminationReason.VOLUME_ERROR: "volume error", + JobTerminationReason.GATEWAY_ERROR: "gateway error", + JobTerminationReason.SCALED_DOWN: "scaled down", + JobTerminationReason.INACTIVITY_DURATION_EXCEEDED: "inactivity duration exceeded", + JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY: "utilization policy", + JobTerminationReason.PORTS_BINDING_FAILED: "ports binding failed", + JobTerminationReason.CREATING_CONTAINER_ERROR: "runner error", + JobTerminationReason.EXECUTOR_ERROR: "executor error", + JobTerminationReason.MAX_DURATION_EXCEEDED: "max duration exceeded", + } + return error_mapping.get(self) + class Requirements(CoreModel): # TODO: Make requirements' fields required @@ -227,6 +254,8 @@ class JobSpec(CoreModel): # TODO: drop this comment when supporting jobs submitted before 0.19.17 is no longer relevant. repo_code_hash: Optional[str] = None file_archives: list[FileArchiveMapping] = [] + # None for non-services and pre-0.19.19 services. 
See `get_service_port` + service_port: Optional[int] = None class JobProvisioningData(CoreModel): @@ -305,13 +334,12 @@ class JobSubmission(CoreModel): finished_at: Optional[datetime] inactivity_secs: Optional[int] status: JobStatus + status_message: str = "" # default for backward compatibility termination_reason: Optional[JobTerminationReason] termination_reason_message: Optional[str] exit_status: Optional[int] job_provisioning_data: Optional[JobProvisioningData] job_runtime_data: Optional[JobRuntimeData] - # TODO: make status_message and error a computed field after migrating to pydanticV2 - status_message: Optional[str] = None error: Optional[str] = None @property @@ -325,71 +353,6 @@ def duration(self) -> timedelta: end_time = self.finished_at return end_time - self.submitted_at - @root_validator - def _status_message(cls, values) -> Dict: - try: - status = values["status"] - termination_reason = values["termination_reason"] - exit_code = values["exit_status"] - except KeyError: - return values - values["status_message"] = JobSubmission._get_status_message( - status=status, - termination_reason=termination_reason, - exit_status=exit_code, - ) - return values - - @staticmethod - def _get_status_message( - status: JobStatus, - termination_reason: Optional[JobTerminationReason], - exit_status: Optional[int], - ) -> str: - if status == JobStatus.DONE: - return "exited (0)" - elif status == JobStatus.FAILED: - if termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR: - return f"exited ({exit_status})" - elif termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY: - return "no offers" - elif termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY: - return "interrupted" - else: - return "error" - elif status == JobStatus.TERMINATED: - if termination_reason == JobTerminationReason.TERMINATED_BY_USER: - return "stopped" - elif termination_reason == JobTerminationReason.ABORTED_BY_USER: - return "aborted" - return 
status.value - - @root_validator - def _error(cls, values) -> Dict: - try: - termination_reason = values["termination_reason"] - except KeyError: - return values - values["error"] = JobSubmission._get_error(termination_reason=termination_reason) - return values - - @staticmethod - def _get_error(termination_reason: Optional[JobTerminationReason]) -> Optional[str]: - error_mapping = { - JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable", - JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded", - JobTerminationReason.VOLUME_ERROR: "volume error", - JobTerminationReason.GATEWAY_ERROR: "gateway error", - JobTerminationReason.SCALED_DOWN: "scaled down", - JobTerminationReason.INACTIVITY_DURATION_EXCEEDED: "inactivity duration exceeded", - JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY: "utilization policy", - JobTerminationReason.PORTS_BINDING_FAILED: "ports binding failed", - JobTerminationReason.CREATING_CONTAINER_ERROR: "runner error", - JobTerminationReason.EXECUTOR_ERROR: "executor error", - JobTerminationReason.MAX_DURATION_EXCEEDED: "max duration exceeded", - } - return error_mapping.get(termination_reason) - class Job(CoreModel): job_spec: JobSpec @@ -524,85 +487,17 @@ class Run(CoreModel): submitted_at: datetime last_processed_at: datetime status: RunStatus - status_message: Optional[str] = None - termination_reason: Optional[RunTerminationReason] + status_message: str = "" # default for backward compatibility + termination_reason: Optional[RunTerminationReason] = None run_spec: RunSpec jobs: List[Job] - latest_job_submission: Optional[JobSubmission] + latest_job_submission: Optional[JobSubmission] = None cost: float = 0 service: Optional[ServiceSpec] = None deployment_num: int = 0 # default for compatibility with pre-0.19.14 servers - # TODO: make error a computed field after migrating to pydanticV2 error: Optional[str] = None deleted: Optional[bool] = None - @root_validator - def _error(cls, values) 
-> Dict: - try: - termination_reason = values["termination_reason"] - except KeyError: - return values - values["error"] = Run._get_error(termination_reason=termination_reason) - return values - - @staticmethod - def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[str]: - if termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED: - return "retry limit exceeded" - elif termination_reason == RunTerminationReason.SERVER_ERROR: - return "server error" - else: - return None - - @root_validator - def _status_message(cls, values) -> Dict: - try: - status = values["status"] - jobs: List[Job] = values["jobs"] - retry_on_events = ( - jobs[0].job_spec.retry.on_events if jobs and jobs[0].job_spec.retry else [] - ) - job_status = ( - jobs[0].job_submissions[-1].status - if len(jobs) == 1 and jobs[0].job_submissions - else None - ) - termination_reason = Run.get_last_termination_reason(jobs[0]) if jobs else None - except KeyError: - return values - values["status_message"] = Run._get_status_message( - status=status, - job_status=job_status, - retry_on_events=retry_on_events, - termination_reason=termination_reason, - ) - return values - - @staticmethod - def get_last_termination_reason(job: "Job") -> Optional[JobTerminationReason]: - for submission in reversed(job.job_submissions): - if submission.termination_reason is not None: - return submission.termination_reason - return None - - @staticmethod - def _get_status_message( - status: RunStatus, - job_status: Optional[JobStatus], - retry_on_events: List[RetryEvent], - termination_reason: Optional[JobTerminationReason], - ) -> str: - if job_status == JobStatus.PULLING: - return "pulling" - # Currently, `retrying` is shown only for `no-capacity` events - if ( - status in [RunStatus.SUBMITTED, RunStatus.PENDING] - and termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY - and RetryEvent.NO_CAPACITY in retry_on_events - ): - return "retrying" - return status.value - def 
is_deployment_in_progress(self) -> bool: return any( not j.job_submissions[-1].status.is_finished() @@ -658,3 +553,11 @@ def get_policy_map(spot_policy: Optional[SpotPolicy], default: SpotPolicy) -> Op SpotPolicy.ONDEMAND: False, } return policy_map[spot_policy] + + +def get_service_port(job_spec: JobSpec, configuration: ServiceConfiguration) -> int: + # Compatibility with pre-0.19.19 job specs that do not have the `service_port` property. + # TODO: drop when pre-0.19.19 jobs are no longer relevant. + if job_spec.service_port is None: + return configuration.port.container_port + return job_spec.service_port diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py index 773fd9429..0f89b770f 100644 --- a/src/dstack/_internal/core/models/volumes.py +++ b/src/dstack/_internal/core/models/volumes.py @@ -9,6 +9,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.profiles import parse_idle_duration from dstack._internal.core.models.resources import Memory from dstack._internal.utils.common import get_or_error from dstack._internal.utils.tags import tags_validator @@ -44,6 +45,16 @@ class VolumeConfiguration(CoreModel): Optional[str], Field(description="The volume ID. Must be specified when registering external volumes"), ] = None + auto_cleanup_duration: Annotated[ + Optional[Union[str, int]], + Field( + description=( + "Time to wait after volume is no longer used by any job before deleting it. " + "Defaults to keep the volume indefinitely. " + "Use the value 'off' or -1 to disable auto-cleanup." 
+ ) + ), + ] = None tags: Annotated[ Optional[Dict[str, str]], Field( @@ -56,6 +67,9 @@ class VolumeConfiguration(CoreModel): ] = None _validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator) + _validate_auto_cleanup_duration = validator( + "auto_cleanup_duration", pre=True, allow_reuse=True + )(parse_idle_duration) @property def size_gb(self) -> int: diff --git a/src/dstack/_internal/core/services/diff.py b/src/dstack/_internal/core/services/diff.py index d50ab90e5..0d63cebc4 100644 --- a/src/dstack/_internal/core/services/diff.py +++ b/src/dstack/_internal/core/services/diff.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, TypedDict +from typing import Any, Optional, TypedDict, TypeVar from pydantic import BaseModel @@ -15,20 +15,19 @@ class ModelFieldDiff(TypedDict): # TODO: calculate nested diffs def diff_models( - old: BaseModel, new: BaseModel, ignore: Optional[IncludeExcludeType] = None + old: BaseModel, new: BaseModel, reset: Optional[IncludeExcludeType] = None ) -> ModelDiff: """ Returns a diff of model instances fields. - NOTE: `ignore` is implemented as `BaseModel.parse_obj(BaseModel.dict(exclude=ignore))`, - that is, the "ignored" fields are actually not ignored but reset to the default values - before comparison, meaning that 1) any field in `ignore` must have a default value, - 2) the default value must be equal to itself (e.g. `math.nan` != `math.nan`). + The fields specified in the `reset` option are reset to their default values, effectively + excluding them from comparison (assuming that the default value is equal to itself, e.g, + `None == None`, `"task" == "task"`, but `math.nan != math.nan`). Args: old: The "old" model instance. new: The "new" model instance. - ignore: Optional fields to ignore. + reset: Fields to reset to their default values before comparison. 
Returns: A dict of changed fields in the form of @@ -37,9 +36,9 @@ def diff_models( if type(old) is not type(new): raise TypeError("Both instances must be of the same Pydantic model class.") - if ignore is not None: - old = type(old).parse_obj(old.dict(exclude=ignore)) - new = type(new).parse_obj(new.dict(exclude=ignore)) + if reset is not None: + old = copy_model(old, reset=reset) + new = copy_model(new, reset=reset) changes: ModelDiff = {} for field in old.__fields__: @@ -49,3 +48,24 @@ def diff_models( changes[field] = {"old": old_value, "new": new_value} return changes + + +M = TypeVar("M", bound=BaseModel) + + +def copy_model(model: M, reset: Optional[IncludeExcludeType] = None) -> M: + """ + Returns a deep copy of the model instance. + + Implemented as `BaseModel.parse_obj(BaseModel.dict())`, thus, + unlike `BaseModel.copy(deep=True)`, runs all validations. + + The fields specified in the `reset` option are reset to their default values. + + Args: + reset: Fields to reset to their default values. + + Returns: + A deep copy of the model instance. 
+ """ + return type(model).parse_obj(model.dict(exclude=reset)) diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py index a095fafb2..d0ad4ac64 100644 --- a/src/dstack/_internal/core/services/ssh/attach.py +++ b/src/dstack/_internal/core/services/ssh/attach.py @@ -64,6 +64,7 @@ def __init__( run_name: str, dockerized: bool, ssh_proxy: Optional[SSHConnectionParams] = None, + service_port: Optional[int] = None, local_backend: bool = False, bind_address: Optional[str] = None, ): @@ -90,6 +91,7 @@ def __init__( }, ) self.ssh_proxy = ssh_proxy + self.service_port = service_port hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {} self.hosts = hosts diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 7435ff8cb..8aff963c1 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -10,7 +10,7 @@ import sentry_sdk from fastapi import FastAPI, Request, Response, status from fastapi.datastructures import URL -from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from prometheus_client import Counter, Histogram @@ -56,6 +56,7 @@ ) from dstack._internal.server.utils.logging import configure_logging from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, check_client_server_compatibility, error_detail, get_server_client_error_details, @@ -90,7 +91,10 @@ def create_app() -> FastAPI: profiles_sample_rate=settings.SENTRY_PROFILES_SAMPLE_RATE, ) - app = FastAPI(docs_url="/api/docs", lifespan=lifespan) + app = FastAPI( + docs_url="/api/docs", + lifespan=lifespan, + ) app.state.proxy_dependency_injector = ServerProxyDependencyInjector() return app @@ -147,7 +151,10 @@ async def lifespan(app: FastAPI): ) if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None: 
init_default_storage() - scheduler = start_background_tasks() + if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: + scheduler = start_background_tasks() + else: + logger.info("Background processing is disabled") dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)" logger.info(f"The admin token is {admin.token.get_plaintext_or_error()}", {"show_path": False}) logger.info( @@ -157,7 +164,8 @@ async def lifespan(app: FastAPI): for func in _ON_STARTUP_HOOKS: await func(app) yield - scheduler.shutdown() + if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: + scheduler.shutdown() await gateway_connections_pool.remove_all() service_conn_pool = await get_injector_from_app(app).get_service_connection_pool() await service_conn_pool.remove_all() @@ -208,14 +216,14 @@ async def forbidden_error_handler(request: Request, exc: ForbiddenError): msg = "Access denied" if len(exc.args) > 0: msg = exc.args[0] - return JSONResponse( + return CustomORJSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=error_detail(msg), ) @app.exception_handler(ServerClientError) async def server_client_error_handler(request: Request, exc: ServerClientError): - return JSONResponse( + return CustomORJSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content={"detail": get_server_client_error_details(exc)}, ) @@ -223,7 +231,7 @@ async def server_client_error_handler(request: Request, exc: ServerClientError): @app.exception_handler(OSError) async def os_error_handler(request, exc: OSError): if exc.errno in [36, 63]: - return JSONResponse( + return CustomORJSONResponse( {"detail": "Filename too long"}, status_code=status.HTTP_400_BAD_REQUEST, ) @@ -309,7 +317,7 @@ async def check_client_version(request: Request, call_next): @app.get("/healthcheck") async def healthcheck(): - return JSONResponse(content={"status": "running"}) + return CustomORJSONResponse(content={"status": "running"}) if ui and Path(__file__).parent.joinpath("statics").exists(): app.mount( @@ -323,7 +331,7 @@ 
async def custom_http_exception_handler(request, exc): or _is_proxy_request(request) or _is_prometheus_request(request) ): - return JSONResponse( + return CustomORJSONResponse( {"detail": exc.detail}, status_code=status.HTTP_404_NOT_FOUND, ) diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index 2dd410cd2..ec927ecfc 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -4,9 +4,10 @@ from dstack._internal.server import settings from dstack._internal.server.background.tasks.process_fleets import process_fleets from dstack._internal.server.background.tasks.process_gateways import ( + process_gateways, process_gateways_connections, - process_submitted_gateways, ) +from dstack._internal.server.background.tasks.process_idle_volumes import process_idle_volumes from dstack._internal.server.background.tasks.process_instances import ( process_instances, ) @@ -70,11 +71,12 @@ def start_background_tasks() -> AsyncIOScheduler: ) _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) + _scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5) _scheduler.add_job( - process_submitted_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5 + process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 ) _scheduler.add_job( - process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5 + process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 ) _scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5)) for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py 
b/src/dstack/_internal/server/background/tasks/process_gateways.py index e2a17aa15..cd6025a1a 100644 --- a/src/dstack/_internal/server/background/tasks/process_gateways.py +++ b/src/dstack/_internal/server/background/tasks/process_gateways.py @@ -16,6 +16,7 @@ gateway_connections_pool, ) from dstack._internal.server.services.locking import advisory_lock_ctx, get_locker +from dstack._internal.server.services.logging import fmt from dstack._internal.utils.common import get_current_datetime from dstack._internal.utils.logging import get_logger @@ -27,14 +28,14 @@ async def process_gateways_connections(): await _process_active_connections() -async def process_submitted_gateways(): +async def process_gateways(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__) async with get_session_ctx() as session: async with lock: res = await session.execute( select(GatewayModel) .where( - GatewayModel.status == GatewayStatus.SUBMITTED, + GatewayModel.status.in_([GatewayStatus.SUBMITTED, GatewayStatus.PROVISIONING]), GatewayModel.id.not_in(lockset), ) .options(lazyload(GatewayModel.gateway_compute)) @@ -48,7 +49,25 @@ async def process_submitted_gateways(): lockset.add(gateway_model.id) try: gateway_model_id = gateway_model.id - await _process_submitted_gateway(session=session, gateway_model=gateway_model) + initial_status = gateway_model.status + if initial_status == GatewayStatus.SUBMITTED: + await _process_submitted_gateway(session=session, gateway_model=gateway_model) + elif initial_status == GatewayStatus.PROVISIONING: + await _process_provisioning_gateway(session=session, gateway_model=gateway_model) + else: + logger.error( + "%s: unexpected gateway status %r", fmt(gateway_model), initial_status.upper() + ) + if gateway_model.status != initial_status: + logger.info( + "%s: gateway status has changed %s -> %s%s", + fmt(gateway_model), + initial_status.upper(), + gateway_model.status.upper(), + f": {gateway_model.status_message}" if 
gateway_model.status_message else "", + ) + gateway_model.last_processed_at = get_current_datetime() + await session.commit() finally: lockset.difference_update([gateway_model_id]) @@ -89,7 +108,7 @@ async def _process_connection(conn: GatewayConnection): async def _process_submitted_gateway(session: AsyncSession, gateway_model: GatewayModel): - logger.info("Started gateway %s provisioning", gateway_model.name) + logger.info("%s: started gateway provisioning", fmt(gateway_model)) # Refetch to load related attributes. # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE. res = await session.execute( @@ -110,8 +129,6 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew except BackendNotAvailable: gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Backend not available" - gateway_model.last_processed_at = get_current_datetime() - await session.commit() return try: @@ -123,53 +140,54 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew ) session.add(gateway_model) gateway_model.status = GatewayStatus.PROVISIONING - await session.commit() - await session.refresh(gateway_model) except BackendError as e: - logger.info( - "Failed to create gateway compute for gateway %s: %s", gateway_model.name, repr(e) - ) + logger.info("%s: failed to create gateway compute: %r", fmt(gateway_model), e) gateway_model.status = GatewayStatus.FAILED status_message = f"Backend error: {repr(e)}" if len(e.args) > 0: status_message = str(e.args[0]) gateway_model.status_message = status_message - gateway_model.last_processed_at = get_current_datetime() - await session.commit() - return except Exception as e: - logger.exception( - "Got exception when creating gateway compute for gateway %s", gateway_model.name - ) + logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model)) gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = f"Unexpected 
error: {repr(e)}" - gateway_model.last_processed_at = get_current_datetime() - await session.commit() - return + +async def _process_provisioning_gateway( + session: AsyncSession, gateway_model: GatewayModel +) -> None: + # Refetch to load related attributes. + # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE. + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway_model.id) + .execution_options(populate_existing=True) + ) + gateway_model = res.unique().scalar_one() + + # FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway: + # - cannot delete the gateway before it is provisioned because the DB model is locked + # - connection retry counter is reset on server restart + # - only one server replica is processing the gateway + # Easy to fix by doing only one connection/configuration attempt per processing iteration. The + # main challenge is applying the same provisioning model to the dstack Sky gateway to avoid + # maintaining a different model for Sky. 
connection = await gateways_services.connect_to_gateway_with_retry( gateway_model.gateway_compute ) if connection is None: gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Failed to connect to gateway" - gateway_model.last_processed_at = get_current_datetime() gateway_model.gateway_compute.deleted = True - await session.commit() return - try: await gateways_services.configure_gateway(connection) except Exception: - logger.exception("Failed to configure gateway %s", gateway_model.name) + logger.exception("%s: failed to configure gateway", fmt(gateway_model)) gateway_model.status = GatewayStatus.FAILED gateway_model.status_message = "Failed to configure gateway" - gateway_model.last_processed_at = get_current_datetime() await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) gateway_model.gateway_compute.active = False - await session.commit() return gateway_model.status = GatewayStatus.RUNNING - gateway_model.last_processed_at = get_current_datetime() - await session.commit() diff --git a/src/dstack/_internal/server/background/tasks/process_idle_volumes.py b/src/dstack/_internal/server/background/tasks/process_idle_volumes.py new file mode 100644 index 000000000..33d9d5a9b --- /dev/null +++ b/src/dstack/_internal/server/background/tasks/process_idle_volumes.py @@ -0,0 +1,139 @@ +import datetime +from typing import List + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport +from dstack._internal.core.errors import BackendNotAvailable +from dstack._internal.core.models.profiles import parse_duration +from dstack._internal.core.models.volumes import VolumeStatus +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ProjectModel, VolumeModel +from dstack._internal.server.services import backends as backends_services 
+from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.volumes import ( + get_volume_configuration, + volume_model_to_volume, +) +from dstack._internal.utils import common +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +async def process_idle_volumes(): + lock, lockset = get_locker(get_db().dialect_name).get_lockset(VolumeModel.__tablename__) + async with get_session_ctx() as session: + async with lock: + res = await session.execute( + select(VolumeModel.id) + .where( + VolumeModel.status == VolumeStatus.ACTIVE, + VolumeModel.deleted == False, + VolumeModel.id.not_in(lockset), + ) + .order_by(VolumeModel.last_processed_at.asc()) + .limit(10) + .with_for_update(skip_locked=True, key_share=True) + ) + volume_ids = list(res.scalars().all()) + if not volume_ids: + return + for volume_id in volume_ids: + lockset.add(volume_id) + + res = await session.execute( + select(VolumeModel) + .where(VolumeModel.id.in_(volume_ids)) + .options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(VolumeModel.user)) + .options(joinedload(VolumeModel.attachments)) + .execution_options(populate_existing=True) + ) + volume_models = list(res.unique().scalars().all()) + try: + volumes_to_delete = [v for v in volume_models if _should_delete_volume(v)] + if not volumes_to_delete: + return + await _delete_idle_volumes(session, volumes_to_delete) + finally: + lockset.difference_update(volume_ids) + + +def _should_delete_volume(volume: VolumeModel) -> bool: + if volume.attachments: + return False + + config = get_volume_configuration(volume) + if not config.auto_cleanup_duration: + return False + + duration_seconds = parse_duration(config.auto_cleanup_duration) + if not duration_seconds or duration_seconds <= 0: + return False + + idle_time = _get_idle_time(volume) + threshold = 
datetime.timedelta(seconds=duration_seconds) + return idle_time > threshold + + +def _get_idle_time(volume: VolumeModel) -> datetime.timedelta: + last_used = volume.last_job_processed_at or volume.created_at + last_used_utc = last_used.replace(tzinfo=datetime.timezone.utc) + idle_time = get_current_datetime() - last_used_utc + return max(idle_time, datetime.timedelta(0)) + + +async def _delete_idle_volumes(session: AsyncSession, volumes: List[VolumeModel]): + # Note: Multiple volumes are deleted in the same transaction, + # so long deletion of one volume may block processing other volumes. + for volume_model in volumes: + logger.info("Deleting idle volume %s", volume_model.name) + try: + await _delete_idle_volume(session, volume_model) + except Exception: + logger.exception("Error when deleting idle volume %s", volume_model.name) + + volume_model.deleted = True + volume_model.deleted_at = get_current_datetime() + + logger.info("Deleted idle volume %s", volume_model.name) + + await session.commit() + + +async def _delete_idle_volume(session: AsyncSession, volume_model: VolumeModel): + volume = volume_model_to_volume(volume_model) + + if volume.provisioning_data is None: + logger.error( + f"Failed to delete volume {volume_model.name}. volume.provisioning_data is None." + ) + return + + if volume.provisioning_data.backend is None: + logger.error( + f"Failed to delete volume {volume_model.name}. volume.provisioning_data.backend is None." + ) + return + + try: + backend = await backends_services.get_project_backend_by_type_or_error( + project=volume_model.project, + backend_type=volume.provisioning_data.backend, + ) + except BackendNotAvailable: + logger.error( + f"Failed to delete volume {volume_model.name}. Backend {volume.provisioning_data.backend} not available."
+ ) + return + + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) + await common.run_async( + compute.delete_volume, + volume=volume, + ) diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index b9c2f9c94..eba9549b6 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -739,3 +739,5 @@ async def _attach_volume( attachment_data=attachment_data.json(), ) instance.volume_attachments.append(volume_attachment_model) + + volume_model.last_job_processed_at = common_utils.get_current_datetime() diff --git a/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py b/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py index a0cac48af..e26206222 100644 --- a/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py +++ b/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py @@ -17,12 +17,6 @@ def upgrade() -> None: - with op.batch_alter_table("jobs", schema=None) as batch_op: - batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True)) - with op.batch_alter_table("jobs", schema=None) as batch_op: - batch_op.execute("UPDATE jobs SET deployment_num = 0") - batch_op.alter_column("deployment_num", nullable=False) - with op.batch_alter_table("runs", schema=None) as batch_op: batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True)) batch_op.add_column(sa.Column("desired_replica_count", sa.Integer(), nullable=True)) @@ -32,6 +26,12 @@ def upgrade() -> None: batch_op.alter_column("deployment_num", nullable=False) batch_op.alter_column("desired_replica_count", nullable=False) + with op.batch_alter_table("jobs", schema=None) as batch_op: + 
batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True)) + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.execute("UPDATE jobs SET deployment_num = 0") + batch_op.alter_column("deployment_num", nullable=False) + def downgrade() -> None: with op.batch_alter_table("runs", schema=None) as batch_op: diff --git a/src/dstack/_internal/server/migrations/versions/d5863798bf41_add_volumemodel_last_job_processed_at.py b/src/dstack/_internal/server/migrations/versions/d5863798bf41_add_volumemodel_last_job_processed_at.py new file mode 100644 index 000000000..1dc883e05 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/d5863798bf41_add_volumemodel_last_job_processed_at.py @@ -0,0 +1,40 @@ +"""Add VolumeModel.last_job_processed_at + +Revision ID: d5863798bf41 +Revises: 644b8a114187 +Create Date: 2025-07-15 14:26:22.981687 + +""" + +import sqlalchemy as sa +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "d5863798bf41" +down_revision = "644b8a114187" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("volumes", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "last_job_processed_at", + dstack._internal.server.models.NaiveDateTime(), + nullable=True, + ) + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("volumes", schema=None) as batch_op: + batch_op.drop_column("last_job_processed_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index d39d07be1..c4dafe81e 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -645,6 +645,7 @@ class VolumeModel(BaseModel): last_processed_at: Mapped[datetime] = mapped_column( NaiveDateTime, default=get_current_datetime ) + last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) deleted: Mapped[bool] = mapped_column(Boolean, default=False) deleted_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) diff --git a/src/dstack/_internal/server/routers/backends.py b/src/dstack/_internal/server/routers/backends.py index 7b6056b92..b43463a90 100644 --- a/src/dstack/_internal/server/routers/backends.py +++ b/src/dstack/_internal/server/routers/backends.py @@ -27,7 +27,10 @@ get_backend_config_yaml, update_backend_config_yaml, ) -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/backends", @@ -41,35 +44,37 @@ ) -@root_router.post("/list_types") -async def list_backend_types() -> List[BackendType]: - return dstack._internal.core.backends.configurators.list_available_backend_types() +@root_router.post("/list_types", response_model=List[BackendType]) +async def list_backend_types(): + return CustomORJSONResponse( + dstack._internal.core.backends.configurators.list_available_backend_types() + ) -@project_router.post("/create") +@project_router.post("/create", response_model=AnyBackendConfigWithCreds) async def create_backend( body: AnyBackendConfigWithCreds, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = 
Depends(ProjectAdmin()), -) -> AnyBackendConfigWithCreds: +): _, project = user_project config = await backends.create_backend(session=session, project=project, config=body) if settings.SERVER_CONFIG_ENABLED: await ServerConfigManager().sync_config(session=session) - return config + return CustomORJSONResponse(config) -@project_router.post("/update") +@project_router.post("/update", response_model=AnyBackendConfigWithCreds) async def update_backend( body: AnyBackendConfigWithCreds, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> AnyBackendConfigWithCreds: +): _, project = user_project config = await backends.update_backend(session=session, project=project, config=body) if settings.SERVER_CONFIG_ENABLED: await ServerConfigManager().sync_config(session=session) - return config + return CustomORJSONResponse(config) @project_router.post("/delete") @@ -86,16 +91,16 @@ async def delete_backends( await ServerConfigManager().sync_config(session=session) -@project_router.post("/{backend_name}/config_info") +@project_router.post("/{backend_name}/config_info", response_model=AnyBackendConfigWithCreds) async def get_backend_config_info( backend_name: BackendType, user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> AnyBackendConfigWithCreds: +): _, project = user_project config = await backends.get_backend_config(project=project, backend_type=backend_name) if config is None: raise ResourceNotExistsError() - return config + return CustomORJSONResponse(config) @project_router.post("/create_yaml") @@ -126,10 +131,12 @@ async def update_backend_yaml( ) -@project_router.post("/{backend_name}/get_yaml") +@project_router.post("/{backend_name}/get_yaml", response_model=BackendInfoYAML) async def get_backend_yaml( backend_name: BackendType, user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> BackendInfoYAML: +): _, project = user_project - return await 
get_backend_config_yaml(project=project, backend_type=backend_name) + return CustomORJSONResponse( + await get_backend_config_yaml(project=project, backend_type=backend_name) + ) diff --git a/src/dstack/_internal/server/routers/files.py b/src/dstack/_internal/server/routers/files.py index 574ef0177..ff7b2a3d5 100644 --- a/src/dstack/_internal/server/routers/files.py +++ b/src/dstack/_internal/server/routers/files.py @@ -12,6 +12,7 @@ from dstack._internal.server.services import files from dstack._internal.server.settings import SERVER_CODE_UPLOAD_LIMIT from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, get_base_api_additional_responses, get_request_size, ) @@ -24,12 +25,12 @@ ) -@router.post("/get_archive_by_hash") +@router.post("/get_archive_by_hash", response_model=FileArchive) async def get_archive_by_hash( body: GetFileArchiveByHashRequest, session: Annotated[AsyncSession, Depends(get_session)], user: Annotated[UserModel, Depends(Authenticated())], -) -> FileArchive: +): archive = await files.get_archive_by_hash( session=session, user=user, @@ -37,16 +38,16 @@ async def get_archive_by_hash( ) if archive is None: raise ResourceNotExistsError() - return archive + return CustomORJSONResponse(archive) -@router.post("/upload_archive") +@router.post("/upload_archive", response_model=FileArchive) async def upload_archive( request: Request, file: UploadFile, session: Annotated[AsyncSession, Depends(get_session)], user: Annotated[UserModel, Depends(Authenticated())], -) -> FileArchive: +): request_size = get_request_size(request) if SERVER_CODE_UPLOAD_LIMIT > 0 and request_size > SERVER_CODE_UPLOAD_LIMIT: diff_size_fmt = sizeof_fmt(request_size) @@ -64,4 +65,4 @@ async def upload_archive( user=user, file=file, ) - return archive + return CustomORJSONResponse(archive) diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py index 92aca9e18..3cbab9508 100644 --- 
a/src/dstack/_internal/server/routers/fleets.py +++ b/src/dstack/_internal/server/routers/fleets.py @@ -18,7 +18,10 @@ ListFleetsRequest, ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/fleets", @@ -32,12 +35,12 @@ ) -@root_router.post("/list") +@root_router.post("/list", response_model=List[Fleet]) async def list_fleets( body: ListFleetsRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Fleet]: +): """ Returns all fleets and instances within them visible to user sorted by descending `created_at`. `project_name` and `only_active` can be specified as filters. @@ -45,36 +48,40 @@ async def list_fleets( The results are paginated. To get the next page, pass `created_at` and `id` of the last fleet from the previous page as `prev_created_at` and `prev_id`. """ - return await fleets_services.list_fleets( - session=session, - user=user, - project_name=body.project_name, - only_active=body.only_active, - prev_created_at=body.prev_created_at, - prev_id=body.prev_id, - limit=body.limit, - ascending=body.ascending, + return CustomORJSONResponse( + await fleets_services.list_fleets( + session=session, + user=user, + project_name=body.project_name, + only_active=body.only_active, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, + ) ) -@project_router.post("/list") +@project_router.post("/list", response_model=List[Fleet]) async def list_project_fleets( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> List[Fleet]: +): """ Returns all fleets in the project. 
""" _, project = user_project - return await fleets_services.list_project_fleets(session=session, project=project) + return CustomORJSONResponse( + await fleets_services.list_project_fleets(session=session, project=project) + ) -@project_router.post("/get") +@project_router.post("/get", response_model=Fleet) async def get_fleet( body: GetFleetRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Fleet: +): """ Returns a fleet given `name` or `id`. If given `name`, does not return deleted fleets. @@ -86,15 +93,15 @@ async def get_fleet( ) if fleet is None: raise ResourceNotExistsError() - return fleet + return CustomORJSONResponse(fleet) -@project_router.post("/get_plan") +@project_router.post("/get_plan", response_model=FleetPlan) async def get_plan( body: GetFleetPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> FleetPlan: +): """ Returns a fleet plan for the given fleet configuration. """ @@ -105,45 +112,49 @@ async def get_plan( user=user, spec=body.spec, ) - return plan + return CustomORJSONResponse(plan) -@project_router.post("/apply") +@project_router.post("/apply", response_model=Fleet) async def apply_plan( body: ApplyFleetPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Fleet: +): """ Creates a new fleet or updates an existing fleet. Errors if the expected current resource from the plan does not match the current resource. Use `force: true` to apply even if the current resource does not match. 
""" user, project = user_project - return await fleets_services.apply_plan( - session=session, - user=user, - project=project, - plan=body.plan, - force=body.force, + return CustomORJSONResponse( + await fleets_services.apply_plan( + session=session, + user=user, + project=project, + plan=body.plan, + force=body.force, + ) ) -@project_router.post("/create") +@project_router.post("/create", response_model=Fleet) async def create_fleet( body: CreateFleetRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Fleet: +): """ Creates a fleet given a fleet configuration. """ user, project = user_project - return await fleets_services.create_fleet( - session=session, - project=project, - user=user, - spec=body.spec, + return CustomORJSONResponse( + await fleets_services.create_fleet( + session=session, + project=project, + user=user, + spec=body.spec, + ) ) diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index e0e0ad37d..fb03a3d69 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -13,7 +13,10 @@ ProjectAdmin, ProjectMemberOrPublicAccess, ) -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/project/{project_name}/gateways", @@ -22,40 +25,44 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[models.Gateway]) async def list_gateways( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), -) -> List[models.Gateway]: +): _, project = user_project - return await gateways.list_project_gateways(session=session, project=project) + return CustomORJSONResponse( + await 
gateways.list_project_gateways(session=session, project=project) + ) -@router.post("/get") +@router.post("/get", response_model=models.Gateway) async def get_gateway( body: schemas.GetGatewayRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), -) -> models.Gateway: +): _, project = user_project gateway = await gateways.get_gateway_by_name(session=session, project=project, name=body.name) if gateway is None: raise ResourceNotExistsError() - return gateway + return CustomORJSONResponse(gateway) -@router.post("/create") +@router.post("/create", response_model=models.Gateway) async def create_gateway( body: schemas.CreateGatewayRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> models.Gateway: +): user, project = user_project - return await gateways.create_gateway( - session=session, - user=user, - project=project, - configuration=body.configuration, + return CustomORJSONResponse( + await gateways.create_gateway( + session=session, + user=user, + project=project, + configuration=body.configuration, + ) ) @@ -83,13 +90,15 @@ async def set_default_gateway( await gateways.set_default_gateway(session=session, project=project, name=body.name) -@router.post("/set_wildcard_domain") +@router.post("/set_wildcard_domain", response_model=models.Gateway) async def set_gateway_wildcard_domain( body: schemas.SetWildcardDomainRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> models.Gateway: +): _, project = user_project - return await gateways.set_gateway_wildcard_domain( - session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain + return CustomORJSONResponse( + await gateways.set_gateway_wildcard_domain( + session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain + ) ) diff 
--git a/src/dstack/_internal/server/routers/instances.py b/src/dstack/_internal/server/routers/instances.py index 489b3bf1c..740c51fd6 100644 --- a/src/dstack/_internal/server/routers/instances.py +++ b/src/dstack/_internal/server/routers/instances.py @@ -9,7 +9,10 @@ from dstack._internal.server.models import UserModel from dstack._internal.server.schemas.instances import ListInstancesRequest from dstack._internal.server.security.permissions import Authenticated -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/instances", @@ -18,12 +21,12 @@ ) -@root_router.post("/list") +@root_router.post("/list", response_model=List[Instance]) async def list_instances( body: ListInstancesRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Instance]: +): """ Returns all instances visible to user sorted by descending `created_at`. `project_names` and `fleet_ids` can be specified as filters. @@ -31,14 +34,16 @@ async def list_instances( The results are paginated. To get the next page, pass `created_at` and `id` of the last instance from the previous page as `prev_created_at` and `prev_id`. 
""" - return await instances.list_user_instances( - session=session, - user=user, - project_names=body.project_names, - fleet_ids=body.fleet_ids, - only_active=body.only_active, - prev_created_at=body.prev_created_at, - prev_id=body.prev_id, - limit=body.limit, - ascending=body.ascending, + return CustomORJSONResponse( + await instances.list_user_instances( + session=session, + user=user, + project_names=body.project_names, + fleet_ids=body.fleet_ids, + only_active=body.only_active, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, + ) ) diff --git a/src/dstack/_internal/server/routers/logs.py b/src/dstack/_internal/server/routers/logs.py index a86424ee6..29685f6f5 100644 --- a/src/dstack/_internal/server/routers/logs.py +++ b/src/dstack/_internal/server/routers/logs.py @@ -7,7 +7,10 @@ from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.server.security.permissions import ProjectMember from dstack._internal.server.services import logs -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/project/{project_name}/logs", @@ -18,13 +21,14 @@ @router.post( "/poll", + response_model=JobSubmissionLogs, ) async def poll_logs( body: PollLogsRequest, user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> JobSubmissionLogs: +): _, project = user_project # The runner guarantees logs have different timestamps if throughput < 1k logs / sec. # Otherwise, some logs with duplicated timestamps may be filtered out. # This limitation is imposed by cloud log services that support up to millisecond timestamp resolution. 
- return await logs.poll_logs_async(project=project, request=body) + return CustomORJSONResponse(await logs.poll_logs_async(project=project, request=body)) diff --git a/src/dstack/_internal/server/routers/metrics.py b/src/dstack/_internal/server/routers/metrics.py index 1d4ffb1db..e61a0d9bf 100644 --- a/src/dstack/_internal/server/routers/metrics.py +++ b/src/dstack/_internal/server/routers/metrics.py @@ -11,7 +11,10 @@ from dstack._internal.server.security.permissions import ProjectMember from dstack._internal.server.services import metrics from dstack._internal.server.services.jobs import get_run_job_model -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/project/{project_name}/metrics", @@ -22,6 +25,7 @@ @router.get( "/job/{run_name}", + response_model=JobMetrics, ) async def get_job_metrics( run_name: str, @@ -32,7 +36,7 @@ async def get_job_metrics( before: Optional[datetime] = None, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> JobMetrics: +): """ Returns job-level metrics such as hardware utilization given `run_name`, `replica_num`, and `job_num`. 
@@ -63,10 +67,12 @@ async def get_job_metrics( if job_model is None: raise ResourceNotExistsError("Found no job with given parameters") - return await metrics.get_job_metrics( - session=session, - job_model=job_model, - limit=limit, - after=after, - before=before, + return CustomORJSONResponse( + await metrics.get_job_metrics( + session=session, + job_model=job_model, + limit=limit, + after=after, + before=before, + ) ) diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py index 1d967c6c8..56d41b6ca 100644 --- a/src/dstack/_internal/server/routers/projects.py +++ b/src/dstack/_internal/server/routers/projects.py @@ -23,7 +23,10 @@ ProjectMemberOrPublicAccess, ) from dstack._internal.server.services import projects -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/projects", @@ -32,30 +35,34 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[Project]) async def list_projects( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Project]: +): """ Returns all projects visible to user sorted by descending `created_at`. `members` and `backends` are always empty - call `/api/projects/{project_name}/get` to retrieve them. 
""" - return await projects.list_user_accessible_projects(session=session, user=user) + return CustomORJSONResponse( + await projects.list_user_accessible_projects(session=session, user=user) + ) -@router.post("/create") +@router.post("/create", response_model=Project) async def create_project( body: CreateProjectRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> Project: - return await projects.create_project( - session=session, - user=user, - project_name=body.project_name, - is_public=body.is_public, +): + return CustomORJSONResponse( + await projects.create_project( + session=session, + user=user, + project_name=body.project_name, + is_public=body.is_public, + ) ) @@ -72,23 +79,24 @@ async def delete_projects( ) -@router.post("/{project_name}/get") +@router.post("/{project_name}/get", response_model=Project) async def get_project( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), -) -> Project: +): _, project = user_project - return projects.project_model_to_project(project) + return CustomORJSONResponse(projects.project_model_to_project(project)) @router.post( "/{project_name}/set_members", + response_model=Project, ) async def set_project_members( body: SetProjectMembersRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectManager()), -) -> Project: +): user, project = user_project await projects.set_project_members( session=session, @@ -97,17 +105,18 @@ async def set_project_members( members=body.members, ) await session.refresh(project) - return projects.project_model_to_project(project) + return CustomORJSONResponse(projects.project_model_to_project(project)) @router.post( "/{project_name}/add_members", + response_model=Project, ) async def add_project_members( body: AddProjectMemberRequest, session: AsyncSession = Depends(get_session), user_project: 
Tuple[UserModel, ProjectModel] = Depends(ProjectManagerOrPublicProject()), -) -> Project: +): user, project = user_project await projects.add_project_members( session=session, @@ -116,17 +125,18 @@ async def add_project_members( members=body.members, ) await session.refresh(project) - return projects.project_model_to_project(project) + return CustomORJSONResponse(projects.project_model_to_project(project)) @router.post( "/{project_name}/remove_members", + response_model=Project, ) async def remove_project_members( body: RemoveProjectMemberRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectManagerOrSelfLeave()), -) -> Project: +): user, project = user_project await projects.remove_project_members( session=session, @@ -135,17 +145,18 @@ async def remove_project_members( usernames=body.usernames, ) await session.refresh(project) - return projects.project_model_to_project(project) + return CustomORJSONResponse(projects.project_model_to_project(project)) @router.post( "/{project_name}/update", + response_model=Project, ) async def update_project( body: UpdateProjectRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> Project: +): user, project = user_project await projects.update_project( session=session, @@ -154,4 +165,4 @@ async def update_project( is_public=body.is_public, ) await session.refresh(project) - return projects.project_model_to_project(project) + return CustomORJSONResponse(projects.project_model_to_project(project)) diff --git a/src/dstack/_internal/server/routers/repos.py b/src/dstack/_internal/server/routers/repos.py index 32e59f631..202732f4f 100644 --- a/src/dstack/_internal/server/routers/repos.py +++ b/src/dstack/_internal/server/routers/repos.py @@ -16,6 +16,7 @@ from dstack._internal.server.services import repos from dstack._internal.server.settings import SERVER_CODE_UPLOAD_LIMIT from 
dstack._internal.server.utils.routers import ( + CustomORJSONResponse, get_base_api_additional_responses, get_request_size, ) @@ -28,21 +29,21 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[RepoHead]) async def list_repos( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> List[RepoHead]: +): _, project = user_project - return await repos.list_repos(session=session, project=project) + return CustomORJSONResponse(await repos.list_repos(session=session, project=project)) -@router.post("/get") +@router.post("/get", response_model=RepoHeadWithCreds) async def get_repo( body: GetRepoRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> RepoHeadWithCreds: +): user, project = user_project repo = await repos.get_repo( session=session, @@ -53,7 +54,7 @@ async def get_repo( ) if repo is None: raise ResourceNotExistsError() - return repo + return CustomORJSONResponse(repo) @router.post("/init") diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index c6a4b60f8..8f3909503 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -18,7 +18,10 @@ ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember from dstack._internal.server.services import runs -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/runs", @@ -32,12 +35,15 @@ ) -@root_router.post("/list") +@root_router.post( + "/list", + response_model=List[Run], +) async def list_runs( body: ListRunsRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Run]: +): """ Returns all 
runs visible to user sorted by descending `submitted_at`. `project_name`, `repo_id`, `username`, and `only_active` can be specified as filters. @@ -47,26 +53,33 @@ async def list_runs( The results are paginated. To get the next page, pass `submitted_at` and `id` of the last run from the previous page as `prev_submitted_at` and `prev_run_id`. """ - return await runs.list_user_runs( - session=session, - user=user, - project_name=body.project_name, - repo_id=body.repo_id, - username=body.username, - only_active=body.only_active, - prev_submitted_at=body.prev_submitted_at, - prev_run_id=body.prev_run_id, - limit=body.limit, - ascending=body.ascending, + return CustomORJSONResponse( + await runs.list_user_runs( + session=session, + user=user, + project_name=body.project_name, + repo_id=body.repo_id, + username=body.username, + only_active=body.only_active, + include_jobs=body.include_jobs, + job_submissions_limit=body.job_submissions_limit, + prev_submitted_at=body.prev_submitted_at, + prev_run_id=body.prev_run_id, + limit=body.limit, + ascending=body.ascending, + ) ) -@project_router.post("/get") +@project_router.post( + "/get", + response_model=Run, +) async def get_run( body: GetRunRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Run: +): """ Returns a run given `run_name` or `id`. If given `run_name`, does not return deleted runs. @@ -81,15 +94,18 @@ async def get_run( ) if run is None: raise ResourceNotExistsError("Run not found") - return run + return CustomORJSONResponse(run) -@project_router.post("/get_plan") +@project_router.post( + "/get_plan", + response_model=RunPlan, +) async def get_plan( body: GetRunPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> RunPlan: +): """ Returns a run plan for the given run spec. This is an optional step before calling `/apply`. 
@@ -102,15 +118,18 @@ async def get_plan( run_spec=body.run_spec, max_offers=body.max_offers, ) - return run_plan + return CustomORJSONResponse(run_plan) -@project_router.post("/apply") +@project_router.post( + "/apply", + response_model=Run, +) async def apply_plan( body: ApplyRunPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Run: +): """ Creates a new run or updates an existing run. Errors if the expected current resource from the plan does not match the current resource. @@ -118,12 +137,14 @@ async def apply_plan( If the existing run is active and cannot be updated, it must be stopped first. """ user, project = user_project - return await runs.apply_plan( - session=session, - user=user, - project=project, - plan=body.plan, - force=body.force, + return CustomORJSONResponse( + await runs.apply_plan( + session=session, + user=user, + project=project, + plan=body.plan, + force=body.force, + ) ) diff --git a/src/dstack/_internal/server/routers/secrets.py b/src/dstack/_internal/server/routers/secrets.py index bbfa26be9..c19f15bcc 100644 --- a/src/dstack/_internal/server/routers/secrets.py +++ b/src/dstack/_internal/server/routers/secrets.py @@ -14,6 +14,7 @@ ) from dstack._internal.server.security.permissions import ProjectAdmin from dstack._internal.server.services import secrets as secrets_services +from dstack._internal.server.utils.routers import CustomORJSONResponse router = APIRouter( prefix="/api/project/{project_name}/secrets", @@ -21,24 +22,26 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[Secret]) async def list_secrets( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> List[Secret]: +): _, project = user_project - return await secrets_services.list_secrets( - session=session, - project=project, + return CustomORJSONResponse( + await secrets_services.list_secrets( + 
session=session, + project=project, + ) ) -@router.post("/get") +@router.post("/get", response_model=Secret) async def get_secret( body: GetSecretRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> Secret: +): _, project = user_project secret = await secrets_services.get_secret( session=session, @@ -47,21 +50,23 @@ async def get_secret( ) if secret is None: raise ResourceNotExistsError() - return secret + return CustomORJSONResponse(secret) -@router.post("/create_or_update") +@router.post("/create_or_update", response_model=Secret) async def create_or_update_secret( body: CreateOrUpdateSecretRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> Secret: +): _, project = user_project - return await secrets_services.create_or_update_secret( - session=session, - project=project, - name=body.name, - value=body.value, + return CustomORJSONResponse( + await secrets_services.create_or_update_secret( + session=session, + project=project, + name=body.name, + value=body.value, + ) ) diff --git a/src/dstack/_internal/server/routers/server.py b/src/dstack/_internal/server/routers/server.py index 31e1e04c9..28c742772 100644 --- a/src/dstack/_internal/server/routers/server.py +++ b/src/dstack/_internal/server/routers/server.py @@ -2,6 +2,7 @@ from dstack._internal import settings from dstack._internal.core.models.server import ServerInfo +from dstack._internal.server.utils.routers import CustomORJSONResponse router = APIRouter( prefix="/api/server", @@ -9,8 +10,10 @@ ) -@router.post("/get_info") -async def get_server_info() -> ServerInfo: - return ServerInfo( - server_version=settings.DSTACK_VERSION, +@router.post("/get_info", response_model=ServerInfo) +async def get_server_info(): + return CustomORJSONResponse( + ServerInfo( + server_version=settings.DSTACK_VERSION, + ) ) diff --git 
a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index 670f9f0a5..abb672914 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -16,7 +16,10 @@ ) from dstack._internal.server.security.permissions import Authenticated, GlobalAdmin from dstack._internal.server.services import users -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/users", @@ -25,41 +28,41 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[User]) async def list_users( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[User]: - return await users.list_users_for_user(session=session, user=user) +): + return CustomORJSONResponse(await users.list_users_for_user(session=session, user=user)) -@router.post("/get_my_user") +@router.post("/get_my_user", response_model=User) async def get_my_user( user: UserModel = Depends(Authenticated()), -) -> User: - return users.user_model_to_user(user) +): + return CustomORJSONResponse(users.user_model_to_user(user)) -@router.post("/get_user") +@router.post("/get_user", response_model=UserWithCreds) async def get_user( body: GetUserRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> UserWithCreds: +): res = await users.get_user_with_creds_by_name( session=session, current_user=user, username=body.username ) if res is None: raise ResourceNotExistsError() - return res + return CustomORJSONResponse(res) -@router.post("/create") +@router.post("/create", response_model=User) async def create_user( body: CreateUserRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(GlobalAdmin()), -) -> User: +): res = await users.create_user( session=session, 
username=body.username, @@ -67,15 +70,15 @@ async def create_user( email=body.email, active=body.active, ) - return users.user_model_to_user(res) + return CustomORJSONResponse(users.user_model_to_user(res)) -@router.post("/update") +@router.post("/update", response_model=User) async def update_user( body: UpdateUserRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(GlobalAdmin()), -) -> User: +): res = await users.update_user( session=session, username=body.username, @@ -85,19 +88,19 @@ async def update_user( ) if res is None: raise ResourceNotExistsError() - return users.user_model_to_user(res) + return CustomORJSONResponse(users.user_model_to_user(res)) -@router.post("/refresh_token") +@router.post("/refresh_token", response_model=UserWithCreds) async def refresh_token( body: RefreshTokenRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> UserWithCreds: +): res = await users.refresh_user_token(session=session, user=user, username=body.username) if res is None: raise ResourceNotExistsError() - return users.user_model_to_user_with_creds(res) + return CustomORJSONResponse(users.user_model_to_user_with_creds(res)) @router.post("/delete") diff --git a/src/dstack/_internal/server/routers/volumes.py b/src/dstack/_internal/server/routers/volumes.py index d4137099f..2ac503470 100644 --- a/src/dstack/_internal/server/routers/volumes.py +++ b/src/dstack/_internal/server/routers/volumes.py @@ -15,7 +15,10 @@ ListVolumesRequest, ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/volumes", @@ -25,12 +28,12 @@ project_router = APIRouter(prefix="/api/project/{project_name}/volumes", tags=["volumes"]) 
-@root_router.post("/list") +@root_router.post("/list", response_model=List[Volume]) async def list_volumes( body: ListVolumesRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Volume]: +): """ Returns all volumes visible to user sorted by descending `created_at`. `project_name` and `only_active` can be specified as filters. @@ -38,36 +41,40 @@ async def list_volumes( The results are paginated. To get the next page, pass `created_at` and `id` of the last fleet from the previous page as `prev_created_at` and `prev_id`. """ - return await volumes_services.list_volumes( - session=session, - user=user, - project_name=body.project_name, - only_active=body.only_active, - prev_created_at=body.prev_created_at, - prev_id=body.prev_id, - limit=body.limit, - ascending=body.ascending, + return CustomORJSONResponse( + await volumes_services.list_volumes( + session=session, + user=user, + project_name=body.project_name, + only_active=body.only_active, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, + ) ) -@project_router.post("/list") +@project_router.post("/list", response_model=List[Volume]) async def list_project_volumes( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> List[Volume]: +): """ Returns all volumes in the project. """ _, project = user_project - return await volumes_services.list_project_volumes(session=session, project=project) + return CustomORJSONResponse( + await volumes_services.list_project_volumes(session=session, project=project) + ) -@project_router.post("/get") +@project_router.post("/get", response_model=Volume) async def get_volume( body: GetVolumeRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Volume: +): """ Returns a volume given a volume name. 
""" @@ -77,24 +84,26 @@ async def get_volume( ) if volume is None: raise ResourceNotExistsError() - return volume + return CustomORJSONResponse(volume) -@project_router.post("/create") +@project_router.post("/create", response_model=Volume) async def create_volume( body: CreateVolumeRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Volume: +): """ Creates a volume given a volume configuration. """ user, project = user_project - return await volumes_services.create_volume( - session=session, - project=project, - user=user, - configuration=body.configuration, + return CustomORJSONResponse( + await volumes_services.create_volume( + session=session, + project=project, + user=user, + configuration=body.configuration, + ) ) diff --git a/src/dstack/_internal/server/schemas/logs.py b/src/dstack/_internal/server/schemas/logs.py index 267f5612f..f97d4fde3 100644 --- a/src/dstack/_internal/server/schemas/logs.py +++ b/src/dstack/_internal/server/schemas/logs.py @@ -9,8 +9,8 @@ class PollLogsRequest(CoreModel): run_name: str job_submission_id: UUID4 - start_time: Optional[datetime] - end_time: Optional[datetime] + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None descending: bool = False next_token: Optional[str] = None limit: int = Field(100, ge=0, le=1000) diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index 8ae875df0..844724371 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -9,12 +9,24 @@ class ListRunsRequest(CoreModel): - project_name: Optional[str] - repo_id: Optional[str] - username: Optional[str] + project_name: Optional[str] = None + repo_id: Optional[str] = None + username: Optional[str] = None only_active: bool = False - prev_submitted_at: Optional[datetime] - prev_run_id: Optional[UUID] + include_jobs: bool = Field( + True, + 
description=("Whether to include `jobs` in the response"), + ) + job_submissions_limit: Optional[int] = Field( + None, + ge=0, + description=( + "Limit number of job submissions returned per job to avoid large responses." + "Drops older job submissions. No effect with `include_jobs: false`" + ), + ) + prev_submitted_at: Optional[datetime] = None + prev_run_id: Optional[UUID] = None limit: int = Field(100, ge=0, le=100) ascending: bool = False diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 9925483e4..93ed23b86 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -1,6 +1,8 @@ import uuid +from collections.abc import Callable from datetime import datetime, timezone -from typing import List, Literal, Optional, Tuple, Union, cast +from functools import wraps +from typing import List, Literal, Optional, Tuple, TypeVar, Union, cast from sqlalchemy import and_, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession @@ -13,10 +15,12 @@ ResourceExistsError, ServerClientError, ) +from dstack._internal.core.models.common import ApplyAction, CoreModel from dstack._internal.core.models.envs import Env from dstack._internal.core.models.fleets import ( ApplyFleetPlanInput, Fleet, + FleetConfiguration, FleetPlan, FleetSpec, FleetStatus, @@ -40,6 +44,7 @@ from dstack._internal.core.models.runs import Requirements, get_policy_map from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.services import validate_dstack_resource_name +from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models from dstack._internal.server.db import get_db from dstack._internal.server.models import ( FleetModel, @@ -49,7 +54,10 @@ ) from dstack._internal.server.services import instances as instances_services from dstack._internal.server.services import offers as offers_services -from 
dstack._internal.server.services.instances import list_active_remote_instances +from dstack._internal.server.services.instances import ( + get_instance_remote_connection_info, + list_active_remote_instances, +) from dstack._internal.server.services.locking import ( get_locker, string_to_lock_id, @@ -178,8 +186,9 @@ async def list_project_fleet_models( async def get_fleet( session: AsyncSession, project: ProjectModel, - name: Optional[str], - fleet_id: Optional[uuid.UUID], + name: Optional[str] = None, + fleet_id: Optional[uuid.UUID] = None, + include_sensitive: bool = False, ) -> Optional[Fleet]: if fleet_id is not None: fleet_model = await get_project_fleet_model_by_id( @@ -193,7 +202,7 @@ async def get_fleet( raise ServerClientError("name or id must be specified") if fleet_model is None: return None - return fleet_model_to_fleet(fleet_model) + return fleet_model_to_fleet(fleet_model, include_sensitive=include_sensitive) async def get_project_fleet_model_by_id( @@ -236,23 +245,32 @@ async def get_plan( spec: FleetSpec, ) -> FleetPlan: # Spec must be copied by parsing to calculate merged_profile - effective_spec = FleetSpec.parse_obj(spec.dict()) + effective_spec = copy_model(spec) effective_spec = await apply_plugin_policies( user=user.name, project=project.name, spec=effective_spec, ) - effective_spec = FleetSpec.parse_obj(effective_spec.dict()) - _validate_fleet_spec_and_set_defaults(spec) + # Spec must be copied by parsing to calculate merged_profile + effective_spec = copy_model(effective_spec) + _validate_fleet_spec_and_set_defaults(effective_spec) + + action = ApplyAction.CREATE current_fleet: Optional[Fleet] = None current_fleet_id: Optional[uuid.UUID] = None + if effective_spec.configuration.name is not None: - current_fleet_model = await get_project_fleet_model_by_name( - session=session, project=project, name=effective_spec.configuration.name + current_fleet = await get_fleet( + session=session, + project=project, + 
name=effective_spec.configuration.name, + include_sensitive=True, ) - if current_fleet_model is not None: - current_fleet = fleet_model_to_fleet(current_fleet_model) - current_fleet_id = current_fleet_model.id + if current_fleet is not None: + _set_fleet_spec_defaults(current_fleet.spec) + if _can_update_fleet_spec(current_fleet.spec, effective_spec): + action = ApplyAction.UPDATE + current_fleet_id = current_fleet.id await _check_ssh_hosts_not_yet_added(session, effective_spec, current_fleet_id) offers = [] @@ -265,7 +283,10 @@ async def get_plan( blocks=effective_spec.configuration.blocks, ) offers = [offer for _, offer in offers_with_backends] + _remove_fleet_spec_sensitive_info(effective_spec) + if current_fleet is not None: + _remove_fleet_spec_sensitive_info(current_fleet.spec) plan = FleetPlan( project_name=project.name, user=user.name, @@ -275,6 +296,7 @@ async def get_plan( offers=offers[:50], total_offers=len(offers), max_offer_price=max((offer.price for offer in offers), default=None), + action=action, ) return plan @@ -327,11 +349,77 @@ async def apply_plan( plan: ApplyFleetPlanInput, force: bool, ) -> Fleet: - return await create_fleet( + spec = await apply_plugin_policies( + user=user.name, + project=project.name, + spec=plan.spec, + ) + # Spec must be copied by parsing to calculate merged_profile + spec = copy_model(spec) + _validate_fleet_spec_and_set_defaults(spec) + + if spec.configuration.ssh_config is not None: + _check_can_manage_ssh_fleets(user=user, project=project) + + configuration = spec.configuration + if configuration.name is None: + return await _create_fleet( + session=session, + project=project, + user=user, + spec=spec, + ) + + fleet_model = await get_project_fleet_model_by_name( + session=session, + project=project, + name=configuration.name, + ) + if fleet_model is None: + return await _create_fleet( + session=session, + project=project, + user=user, + spec=spec, + ) + + instances_ids = sorted(i.id for i in fleet_model.instances if 
not i.deleted) + await session.commit() + async with ( + get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, [fleet_model.id]), + get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids), + ): + # Refetch after lock + # TODO: Lock instances with FOR UPDATE? + res = await session.execute( + select(FleetModel) + .where( + FleetModel.project_id == project.id, + FleetModel.id == fleet_model.id, + FleetModel.deleted == False, + ) + .options(selectinload(FleetModel.instances)) + .options(selectinload(FleetModel.runs)) + .execution_options(populate_existing=True) + .order_by(FleetModel.id) # take locks in order + .with_for_update(key_share=True) + ) + fleet_model = res.scalars().unique().one_or_none() + if fleet_model is not None: + return await _update_fleet( + session=session, + project=project, + spec=spec, + current_resource=plan.current_resource, + force=force, + fleet_model=fleet_model, + ) + + return await _create_fleet( session=session, project=project, user=user, - spec=plan.spec, + spec=spec, ) @@ -341,73 +429,19 @@ async def create_fleet( user: UserModel, spec: FleetSpec, ) -> Fleet: - # Spec must be copied by parsing to calculate merged_profile spec = await apply_plugin_policies( user=user.name, project=project.name, spec=spec, ) - spec = FleetSpec.parse_obj(spec.dict()) + # Spec must be copied by parsing to calculate merged_profile + spec = copy_model(spec) _validate_fleet_spec_and_set_defaults(spec) if spec.configuration.ssh_config is not None: _check_can_manage_ssh_fleets(user=user, project=project) - lock_namespace = f"fleet_names_{project.name}" - if get_db().dialect_name == "sqlite": - # Start new transaction to see committed changes after lock - await session.commit() - elif get_db().dialect_name == "postgresql": - await session.execute( - select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) - ) - - lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace) - async with lock: - 
if spec.configuration.name is not None: - fleet_model = await get_project_fleet_model_by_name( - session=session, - project=project, - name=spec.configuration.name, - ) - if fleet_model is not None: - raise ResourceExistsError() - else: - spec.configuration.name = await generate_fleet_name(session=session, project=project) - - fleet_model = FleetModel( - id=uuid.uuid4(), - name=spec.configuration.name, - project=project, - status=FleetStatus.ACTIVE, - spec=spec.json(), - instances=[], - ) - session.add(fleet_model) - if spec.configuration.ssh_config is not None: - for i, host in enumerate(spec.configuration.ssh_config.hosts): - instances_model = await create_fleet_ssh_instance_model( - project=project, - spec=spec, - ssh_params=spec.configuration.ssh_config, - env=spec.configuration.env, - instance_num=i, - host=host, - ) - fleet_model.instances.append(instances_model) - else: - for i in range(_get_fleet_nodes_to_provision(spec)): - instance_model = await create_fleet_instance_model( - session=session, - project=project, - user=user, - spec=spec, - reservation=spec.configuration.reservation, - instance_num=i, - ) - fleet_model.instances.append(instance_model) - await session.commit() - return fleet_model_to_fleet(fleet_model) + return await _create_fleet(session=session, project=project, user=user, spec=spec) async def create_fleet_instance_model( @@ -600,6 +634,235 @@ def is_fleet_empty(fleet_model: FleetModel) -> bool: return len(active_instances) == 0 +async def _create_fleet( + session: AsyncSession, + project: ProjectModel, + user: UserModel, + spec: FleetSpec, +) -> Fleet: + lock_namespace = f"fleet_names_{project.name}" + if get_db().dialect_name == "sqlite": + # Start new transaction to see committed changes after lock + await session.commit() + elif get_db().dialect_name == "postgresql": + await session.execute( + select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) + ) + + lock, _ = 
get_locker(get_db().dialect_name).get_lockset(lock_namespace) + async with lock: + if spec.configuration.name is not None: + fleet_model = await get_project_fleet_model_by_name( + session=session, + project=project, + name=spec.configuration.name, + ) + if fleet_model is not None: + raise ResourceExistsError() + else: + spec.configuration.name = await generate_fleet_name(session=session, project=project) + + fleet_model = FleetModel( + id=uuid.uuid4(), + name=spec.configuration.name, + project=project, + status=FleetStatus.ACTIVE, + spec=spec.json(), + instances=[], + ) + session.add(fleet_model) + if spec.configuration.ssh_config is not None: + for i, host in enumerate(spec.configuration.ssh_config.hosts): + instances_model = await create_fleet_ssh_instance_model( + project=project, + spec=spec, + ssh_params=spec.configuration.ssh_config, + env=spec.configuration.env, + instance_num=i, + host=host, + ) + fleet_model.instances.append(instances_model) + else: + for i in range(_get_fleet_nodes_to_provision(spec)): + instance_model = await create_fleet_instance_model( + session=session, + project=project, + user=user, + spec=spec, + reservation=spec.configuration.reservation, + instance_num=i, + ) + fleet_model.instances.append(instance_model) + await session.commit() + return fleet_model_to_fleet(fleet_model) + + +async def _update_fleet( + session: AsyncSession, + project: ProjectModel, + spec: FleetSpec, + current_resource: Optional[Fleet], + force: bool, + fleet_model: FleetModel, +) -> Fleet: + fleet = fleet_model_to_fleet(fleet_model) + _set_fleet_spec_defaults(fleet.spec) + fleet_sensitive = fleet_model_to_fleet(fleet_model, include_sensitive=True) + _set_fleet_spec_defaults(fleet_sensitive.spec) + + if not force: + if current_resource is not None: + _set_fleet_spec_defaults(current_resource.spec) + if ( + current_resource is None + or current_resource.id != fleet.id + or current_resource.spec != fleet.spec + ): + raise ServerClientError( + "Failed to apply 
plan. Resource has been changed. Try again or use force apply." + ) + + _check_can_update_fleet_spec(fleet_sensitive.spec, spec) + + spec_json = spec.json() + fleet_model.spec = spec_json + + if ( + fleet_sensitive.spec.configuration.ssh_config is not None + and spec.configuration.ssh_config is not None + ): + added_hosts, removed_hosts, changed_hosts = _calculate_ssh_hosts_changes( + current=fleet_sensitive.spec.configuration.ssh_config.hosts, + new=spec.configuration.ssh_config.hosts, + ) + # `_check_can_update_fleet_spec` ensures hosts are not changed + assert not changed_hosts, changed_hosts + active_instance_nums: set[int] = set() + removed_instance_nums: list[int] = [] + if removed_hosts or added_hosts: + for instance_model in fleet_model.instances: + if instance_model.deleted: + continue + active_instance_nums.add(instance_model.instance_num) + rci = get_instance_remote_connection_info(instance_model) + if rci is None: + logger.error( + "Cloud instance %s in SSH fleet %s", + instance_model.id, + fleet_model.id, + ) + continue + if rci.host in removed_hosts: + removed_instance_nums.append(instance_model.instance_num) + if added_hosts: + await _check_ssh_hosts_not_yet_added(session, spec, fleet.id) + for host in added_hosts.values(): + instance_num = _get_next_instance_num(active_instance_nums) + instance_model = await create_fleet_ssh_instance_model( + project=project, + spec=spec, + ssh_params=spec.configuration.ssh_config, + env=spec.configuration.env, + instance_num=instance_num, + host=host, + ) + fleet_model.instances.append(instance_model) + active_instance_nums.add(instance_num) + if removed_instance_nums: + _terminate_fleet_instances(fleet_model, removed_instance_nums) + + await session.commit() + return fleet_model_to_fleet(fleet_model) + + +def _can_update_fleet_spec(current_fleet_spec: FleetSpec, new_fleet_spec: FleetSpec) -> bool: + try: + _check_can_update_fleet_spec(current_fleet_spec, new_fleet_spec) + except ServerClientError as e: + 
logger.debug("Run cannot be updated: %s", repr(e)) + return False + return True + + +M = TypeVar("M", bound=CoreModel) + + +def _check_can_update(*updatable_fields: str): + def decorator(fn: Callable[[M, M, ModelDiff], None]) -> Callable[[M, M], None]: + @wraps(fn) + def inner(current: M, new: M): + diff = _check_can_update_inner(current, new, updatable_fields) + fn(current, new, diff) + + return inner + + return decorator + + +def _check_can_update_inner(current: M, new: M, updatable_fields: tuple[str, ...]) -> ModelDiff: + diff = diff_models(current, new) + changed_fields = diff.keys() + if not (changed_fields <= set(updatable_fields)): + raise ServerClientError( + f"Failed to update fields {list(changed_fields)}." + f" Can only update {list(updatable_fields)}." + ) + return diff + + +@_check_can_update("configuration", "configuration_path") +def _check_can_update_fleet_spec(current: FleetSpec, new: FleetSpec, diff: ModelDiff): + if "configuration" in diff: + _check_can_update_fleet_configuration(current.configuration, new.configuration) + + +@_check_can_update("ssh_config") +def _check_can_update_fleet_configuration( + current: FleetConfiguration, new: FleetConfiguration, diff: ModelDiff +): + if "ssh_config" in diff: + current_ssh_config = current.ssh_config + new_ssh_config = new.ssh_config + if current_ssh_config is None: + if new_ssh_config is not None: + raise ServerClientError("Fleet type changed from Cloud to SSH, cannot update") + elif new_ssh_config is None: + raise ServerClientError("Fleet type changed from SSH to Cloud, cannot update") + else: + _check_can_update_ssh_config(current_ssh_config, new_ssh_config) + + +@_check_can_update("hosts") +def _check_can_update_ssh_config(current: SSHParams, new: SSHParams, diff: ModelDiff): + if "hosts" in diff: + _, _, changed_hosts = _calculate_ssh_hosts_changes(current.hosts, new.hosts) + if changed_hosts: + raise ServerClientError( + f"Hosts configuration changed, cannot update: {list(changed_hosts)}" + ) + + 
+def _calculate_ssh_hosts_changes( + current: list[Union[SSHHostParams, str]], new: list[Union[SSHHostParams, str]] +) -> tuple[dict[str, Union[SSHHostParams, str]], set[str], set[str]]: + current_hosts = {h if isinstance(h, str) else h.hostname: h for h in current} + new_hosts = {h if isinstance(h, str) else h.hostname: h for h in new} + added_hosts = {h: new_hosts[h] for h in new_hosts.keys() - current_hosts} + removed_hosts = current_hosts.keys() - new_hosts + changed_hosts: set[str] = set() + for host in current_hosts.keys() & new_hosts: + current_host = current_hosts[host] + new_host = new_hosts[host] + if isinstance(current_host, str) or isinstance(new_host, str): + if current_host != new_host: + changed_hosts.add(host) + elif diff_models( + current_host, new_host, reset={"identity_file": True, "proxy_jump": {"identity_file"}} + ): + changed_hosts.add(host) + return added_hosts, removed_hosts, changed_hosts + + def _check_can_manage_ssh_fleets(user: UserModel, project: ProjectModel): if user.global_role == GlobalRole.ADMIN: return @@ -654,6 +917,8 @@ def _validate_fleet_spec_and_set_defaults(spec: FleetSpec): validate_dstack_resource_name(spec.configuration.name) if spec.configuration.ssh_config is None and spec.configuration.nodes is None: raise ServerClientError("No ssh_config or nodes specified") + if spec.configuration.ssh_config is not None and spec.configuration.nodes is not None: + raise ServerClientError("ssh_config and nodes are mutually exclusive") if spec.configuration.ssh_config is not None: _validate_all_ssh_params_specified(spec.configuration.ssh_config) if spec.configuration.ssh_config.ssh_key is not None: @@ -662,6 +927,10 @@ def _validate_fleet_spec_and_set_defaults(spec: FleetSpec): if isinstance(host, SSHHostParams) and host.ssh_key is not None: _validate_ssh_key(host.ssh_key) _validate_internal_ips(spec.configuration.ssh_config) + _set_fleet_spec_defaults(spec) + + +def _set_fleet_spec_defaults(spec: FleetSpec): if 
spec.configuration.resources is not None: set_resources_defaults(spec.configuration.resources) @@ -734,3 +1003,16 @@ def _get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements: reservation=fleet_spec.configuration.reservation, ) return requirements + + +def _get_next_instance_num(instance_nums: set[int]) -> int: + if not instance_nums: + return 0 + min_instance_num = min(instance_nums) + if min_instance_num > 0: + return 0 + instance_num = min_instance_num + 1 + while True: + if instance_num not in instance_nums: + return instance_num + instance_num += 1 diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 156459257..fb3839316 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -2,6 +2,7 @@ import datetime import uuid from datetime import timedelta, timezone +from functools import partial from typing import List, Optional, Sequence import httpx @@ -186,6 +187,7 @@ async def create_gateway( return gateway_model_to_gateway(gateway) +# NOTE: dstack Sky imports and uses this function async def connect_to_gateway_with_retry( gateway_compute: GatewayComputeModel, ) -> Optional[GatewayConnection]: @@ -380,6 +382,8 @@ async def get_or_add_gateway_connection( async def init_gateways(session: AsyncSession): res = await session.execute( select(GatewayComputeModel).where( + # FIXME: should not include computes related to gateways in the `provisioning` status. + # Causes warnings and delays when restarting the server during gateway provisioning. 
GatewayComputeModel.active == True, GatewayComputeModel.deleted == False, ) @@ -421,7 +425,8 @@ async def init_gateways(session: AsyncSession): for gateway_compute, error in await gather_map_async( await gateway_connections_pool.all(), - configure_gateway, + # Need several attempts to handle short gateway downtime after update + partial(configure_gateway, attempts=7), return_exceptions=True, ): if isinstance(error, Exception): @@ -461,7 +466,11 @@ def _recently_updated(gateway_compute_model: GatewayComputeModel) -> bool: ) > get_current_datetime() - timedelta(seconds=60) -async def configure_gateway(connection: GatewayConnection) -> None: +# NOTE: dstack Sky imports and uses this function +async def configure_gateway( + connection: GatewayConnection, + attempts: int = GATEWAY_CONFIGURE_ATTEMPTS, +) -> None: """ Try submitting gateway config several times in case gateway's HTTP server is not running yet @@ -469,7 +478,7 @@ async def configure_gateway(connection: GatewayConnection) -> None: logger.debug("Configuring gateway %s", connection.ip_address) - for attempt in range(GATEWAY_CONFIGURE_ATTEMPTS - 1): + for attempt in range(attempts - 1): try: async with connection.client() as client: await client.submit_gateway_config() @@ -478,7 +487,7 @@ async def configure_gateway(connection: GatewayConnection) -> None: logger.debug( "Failed attempt %s/%s at configuring gateway %s: %r", attempt + 1, - GATEWAY_CONFIGURE_ATTEMPTS, + attempts, connection.ip_address, e, ) diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index ed87bc84f..aa4b4823c 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -7,9 +7,9 @@ from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import GatewayError -from dstack._internal.core.models.configurations import RateLimit +from 
dstack._internal.core.models.configurations import RateLimit, ServiceConfiguration from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.runs import JobSubmission, Run +from dstack._internal.core.models.runs import JobSpec, JobSubmission, Run, get_service_port from dstack._internal.proxy.gateway.schemas.stats import ServiceStats from dstack._internal.server import settings @@ -80,13 +80,15 @@ async def unregister_service(self, project: str, run_name: str): async def register_replica( self, run: Run, + job_spec: JobSpec, job_submission: JobSubmission, ssh_head_proxy: Optional[SSHConnectionParams], ssh_head_proxy_private_key: Optional[str], ): + assert isinstance(run.run_spec.configuration, ServiceConfiguration) payload = { "job_id": job_submission.id.hex, - "app_port": run.run_spec.configuration.port.container_port, + "app_port": get_service_port(job_spec, run.run_spec.configuration), "ssh_head_proxy": ssh_head_proxy.dict() if ssh_head_proxy is not None else None, "ssh_head_proxy_private_key": ssh_head_proxy_private_key, } diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 53ca55396..c2dd5c2a1 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -106,6 +106,14 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements: return Requirements.__response__.parse_raw(instance_model.requirements) +def get_instance_remote_connection_info( + instance_model: InstanceModel, +) -> Optional[RemoteConnectionInfo]: + if instance_model.remote_connection_info is None: + return None + return RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info) + + def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, Optional[str]]: """ Returns a pair of SSH private keys: host key and optional proxy jump key. 
diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 41aa496be..8343f4069 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -134,6 +134,8 @@ def job_model_to_job_submission(job_model: JobModel) -> JobSubmission: finished_at = None if job_model.status.is_finished(): finished_at = last_processed_at + status_message = _get_job_status_message(job_model) + error = _get_job_error(job_model) return JobSubmission( id=job_model.id, submission_num=job_model.submission_num, @@ -143,11 +145,13 @@ def job_model_to_job_submission(job_model: JobModel) -> JobSubmission: finished_at=finished_at, inactivity_secs=job_model.inactivity_secs, status=job_model.status, + status_message=status_message, termination_reason=job_model.termination_reason, termination_reason_message=job_model.termination_reason_message, exit_status=job_model.exit_status, job_provisioning_data=job_provisioning_data, job_runtime_data=get_job_runtime_data(job_model), + error=error, ) @@ -289,6 +293,19 @@ async def process_terminating_job( # so that stuck volumes don't prevent the instance from terminating. 
job_model.instance_id = None instance_model.last_job_processed_at = common.get_current_datetime() + + volume_names = ( + jrd.volume_names + if jrd and jrd.volume_names + else [va.volume.name for va in instance_model.volume_attachments] + ) + if volume_names: + volumes = await list_project_volume_models( + session=session, project=instance_model.project, names=volume_names + ) + for volume in volumes: + volume.last_job_processed_at = common.get_current_datetime() + logger.info( "%s: instance '%s' has been released, new status is %s", fmt(job_model), @@ -693,3 +710,31 @@ def _get_job_mount_point_attached_volume( continue return volume raise ServerClientError("Failed to find an eligible volume for the mount point") + + +def _get_job_status_message(job_model: JobModel) -> str: + if job_model.status == JobStatus.DONE: + return "exited (0)" + elif job_model.status == JobStatus.FAILED: + if job_model.termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR: + return f"exited ({job_model.exit_status})" + elif ( + job_model.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + ): + return "no offers" + elif job_model.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY: + return "interrupted" + else: + return "error" + elif job_model.status == JobStatus.TERMINATED: + if job_model.termination_reason == JobTerminationReason.TERMINATED_BY_USER: + return "stopped" + elif job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER: + return "aborted" + return job_model.status.value + + +def _get_job_error(job_model: JobModel) -> Optional[str]: + if job_model.termination_reason is None: + return None + return job_model.termination_reason.to_error() diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 079e47f0b..e1fcee597 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ 
b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -15,6 +15,7 @@ PortMapping, PythonVersion, RunConfigurationType, + ServiceConfiguration, ) from dstack._internal.core.models.profiles import ( DEFAULT_STOP_DURATION, @@ -153,6 +154,7 @@ async def _get_job_spec( repo_data=self.run_spec.repo_data, repo_code_hash=self.run_spec.repo_code_hash, file_archives=self.run_spec.file_archives, + service_port=self._service_port(), ) return job_spec @@ -306,6 +308,11 @@ def _ssh_key(self, jobs_per_replica: int) -> Optional[JobSSHKey]: ) return self._job_ssh_key + def _service_port(self) -> Optional[int]: + if isinstance(self.run_spec.configuration, ServiceConfiguration): + return self.run_spec.configuration.port.container_port + return None + def interpolate_job_volumes( run_volumes: List[Union[MountPoint, str]], diff --git a/src/dstack/_internal/server/services/locking.py b/src/dstack/_internal/server/services/locking.py index 37807b37a..4c3b7f938 100644 --- a/src/dstack/_internal/server/services/locking.py +++ b/src/dstack/_internal/server/services/locking.py @@ -172,7 +172,7 @@ async def _wait_to_lock_many( The keys must be sorted to prevent deadlock. 
""" left_to_lock = keys.copy() - while len(left_to_lock) > 0: + while True: async with lock: locked_now_num = 0 for key in left_to_lock: @@ -182,4 +182,6 @@ async def _wait_to_lock_many( locked.add(key) locked_now_num += 1 left_to_lock = left_to_lock[locked_now_num:] + if not left_to_lock: + return await asyncio.sleep(delay) diff --git a/src/dstack/_internal/server/services/logging.py b/src/dstack/_internal/server/services/logging.py index 1f2d106a5..545067d6a 100644 --- a/src/dstack/_internal/server/services/logging.py +++ b/src/dstack/_internal/server/services/logging.py @@ -1,12 +1,14 @@ from typing import Union -from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.models import GatewayModel, JobModel, RunModel -def fmt(model: Union[RunModel, JobModel]) -> str: +def fmt(model: Union[RunModel, JobModel, GatewayModel]) -> str: """Consistent string representation of a model for logging.""" if isinstance(model, RunModel): return f"run({model.id.hex[:6]}){model.run_name}" if isinstance(model, JobModel): return f"job({model.id.hex[:6]}){model.job_name}" + if isinstance(model, GatewayModel): + return f"gateway({model.id.hex[:6]}){model.name}" return str(model) diff --git a/src/dstack/_internal/server/services/logs/__init__.py b/src/dstack/_internal/server/services/logs/__init__.py index a3623a19d..b38264980 100644 --- a/src/dstack/_internal/server/services/logs/__init__.py +++ b/src/dstack/_internal/server/services/logs/__init__.py @@ -8,7 +8,11 @@ from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent from dstack._internal.server.services.logs.aws import BOTO_AVAILABLE, CloudWatchLogStorage -from dstack._internal.server.services.logs.base import LogStorage, LogStorageError +from dstack._internal.server.services.logs.base import ( + LogStorage, + LogStorageError, + b64encode_raw_message, +) from dstack._internal.server.services.logs.filelog 
import FileLogStorage from dstack._internal.server.services.logs.gcp import GCP_LOGGING_AVAILABLE, GCPLogStorage from dstack._internal.utils.common import run_async @@ -75,4 +79,13 @@ def write_logs( async def poll_logs_async(project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs: - return await run_async(get_log_storage().poll_logs, project=project, request=request) + job_submission_logs = await run_async( + get_log_storage().poll_logs, project=project, request=request + ) + # Logs are stored in plaintext but transmitted in base64 for API/CLI backward compatibility. + # Old logs stored in base64 are encoded twice for transmission and shown as base64 in CLI/UI. + # We live with that. + # TODO: Drop base64 encoding in 0.20. + for log_event in job_submission_logs.logs: + log_event.message = b64encode_raw_message(log_event.message.encode()) + return job_submission_logs diff --git a/src/dstack/_internal/server/services/logs/aws.py b/src/dstack/_internal/server/services/logs/aws.py index 92155763c..616db94db 100644 --- a/src/dstack/_internal/server/services/logs/aws.py +++ b/src/dstack/_internal/server/services/logs/aws.py @@ -17,7 +17,6 @@ from dstack._internal.server.services.logs.base import ( LogStorage, LogStorageError, - b64encode_raw_message, datetime_to_unix_time_ms, unix_time_ms_to_datetime, ) @@ -238,8 +237,7 @@ def _get_next_batch( skipped_future_events += 1 continue cw_event = self._runner_log_event_to_cloudwatch_event(event) - # as message is base64-encoded, length in bytes = length in code points. 
- message_size = len(cw_event["message"]) + self.MESSAGE_OVERHEAD_SIZE + message_size = len(event.message) + self.MESSAGE_OVERHEAD_SIZE if message_size > self.MESSAGE_MAX_SIZE: # we should never hit this limit, as we use `io.Copy` to copy from pty to logs, # which under the hood uses 32KiB buffer, see runner/internal/executor/executor.go, @@ -271,7 +269,7 @@ def _runner_log_event_to_cloudwatch_event( ) -> _CloudWatchLogEvent: return { "timestamp": runner_log_event.timestamp, - "message": b64encode_raw_message(runner_log_event.message), + "message": runner_log_event.message.decode(errors="replace"), } @contextmanager diff --git a/src/dstack/_internal/server/services/logs/filelog.py b/src/dstack/_internal/server/services/logs/filelog.py index 905ee3527..10cbe3d1a 100644 --- a/src/dstack/_internal/server/services/logs/filelog.py +++ b/src/dstack/_internal/server/services/logs/filelog.py @@ -2,6 +2,7 @@ from typing import List, Union from uuid import UUID +from dstack._internal.core.errors import ServerClientError from dstack._internal.core.models.logs import ( JobSubmissionLogs, LogEvent, @@ -14,8 +15,6 @@ from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent from dstack._internal.server.services.logs.base import ( LogStorage, - LogStorageError, - b64encode_raw_message, unix_time_ms_to_datetime, ) @@ -30,9 +29,6 @@ def __init__(self, root: Union[Path, str, None] = None) -> None: self.root = Path(root) def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs: - if request.descending: - raise LogStorageError("descending: true is not supported") - log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB log_file_path = self._get_log_file_path( project_name=project.name, @@ -46,11 +42,11 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi try: start_line = int(request.next_token) if start_line < 0: - raise LogStorageError( + raise ServerClientError( f"Invalid 
next_token: {request.next_token}. Must be a non-negative integer." ) except ValueError: - raise LogStorageError( + raise ServerClientError( f"Invalid next_token: {request.next_token}. Must be a valid integer." ) @@ -60,31 +56,41 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi try: with open(log_file_path) as f: - lines = f.readlines() - - for i, line in enumerate(lines): - if current_line < start_line: + # Skip to start_line if needed + for _ in range(start_line): + if f.readline() == "": + # File is shorter than start_line + return JobSubmissionLogs(logs=logs, next_token=next_token) current_line += 1 - continue - log_event = LogEvent.__response__.parse_raw(line) - current_line += 1 + # Read lines one by one + while True: + line = f.readline() + if line == "": # EOF + break + + current_line += 1 - if request.start_time and log_event.timestamp <= request.start_time: - continue - if request.end_time is not None and log_event.timestamp >= request.end_time: - break + try: + log_event = LogEvent.__response__.parse_raw(line) + except Exception: + # Skip malformed lines + continue - logs.append(log_event) + if request.start_time and log_event.timestamp <= request.start_time: + continue + if request.end_time is not None and log_event.timestamp >= request.end_time: + break - if len(logs) >= request.limit: - # Only set next_token if there are more lines to read - if current_line < len(lines): - next_token = str(current_line) - break + logs.append(log_event) - except IOError as e: - raise LogStorageError(f"Failed to read log file {log_file_path}: {e}") + if len(logs) >= request.limit: + # Check if there are more lines to read + if f.readline() != "": + next_token = str(current_line) + break + except FileNotFoundError: + pass return JobSubmissionLogs(logs=logs, next_token=next_token) @@ -140,5 +146,5 @@ def _runner_log_event_to_log_event(self, runner_log_event: RunnerLogEvent) -> Lo return LogEvent( 
timestamp=unix_time_ms_to_datetime(runner_log_event.timestamp), log_source=LogEventSource.STDOUT, - message=b64encode_raw_message(runner_log_event.message), + message=runner_log_event.message.decode(errors="replace"), ) diff --git a/src/dstack/_internal/server/services/logs/gcp.py b/src/dstack/_internal/server/services/logs/gcp.py index ac228e19e..6e9314df2 100644 --- a/src/dstack/_internal/server/services/logs/gcp.py +++ b/src/dstack/_internal/server/services/logs/gcp.py @@ -14,7 +14,6 @@ from dstack._internal.server.services.logs.base import ( LogStorage, LogStorageError, - b64encode_raw_message, unix_time_ms_to_datetime, ) from dstack._internal.utils.common import batched @@ -137,15 +136,14 @@ def _write_logs_to_stream(self, stream_name: str, logs: List[RunnerLogEvent]): with self.logger.batch() as batcher: for batch in batched(logs, self.MAX_BATCH_SIZE): for log in batch: - message = b64encode_raw_message(log.message) + message = log.message.decode(errors="replace") timestamp = unix_time_ms_to_datetime(log.timestamp) - # as message is base64-encoded, length in bytes = length in code points - if len(message) > self.MAX_RUNNER_MESSAGE_SIZE: + if len(log.message) > self.MAX_RUNNER_MESSAGE_SIZE: logger.error( "Stream %s: skipping event at %s, message exceeds max size: %d > %d", stream_name, timestamp.isoformat(), - len(message), + len(log.message), self.MAX_RUNNER_MESSAGE_SIZE, ) continue diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index 4578cb56f..7c21e840b 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -12,10 +12,12 @@ from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams from dstack._internal.core.models.runs import ( JobProvisioningData, + JobSpec, JobStatus, RunSpec, RunStatus, ServiceSpec, + get_service_port, ) from dstack._internal.core.models.services import AnyModel from 
dstack._internal.proxy.lib.models import ( @@ -97,9 +99,10 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic if rci.ssh_proxy is not None: ssh_head_proxy = rci.ssh_proxy ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private + job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data) replica = Replica( id=job.id.hex, - app_port=run_spec.configuration.port.container_port, + app_port=get_service_port(job_spec, run_spec.configuration), ssh_destination=ssh_destination, ssh_port=ssh_port, ssh_proxy=ssh_proxy, diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 79d9e0a20..33b7c8299 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -24,6 +24,7 @@ ) from dstack._internal.core.models.profiles import ( CreationPolicy, + RetryEvent, ) from dstack._internal.core.models.repos.virtual import DEFAULT_VIRTUAL_REPO_ID, VirtualRunRepoData from dstack._internal.core.models.runs import ( @@ -105,6 +106,8 @@ async def list_user_runs( repo_id: Optional[str], username: Optional[str], only_active: bool, + include_jobs: bool, + job_submissions_limit: Optional[int], prev_submitted_at: Optional[datetime], prev_run_id: Optional[uuid.UUID], limit: int, @@ -148,7 +151,14 @@ async def list_user_runs( runs = [] for r in run_models: try: - runs.append(run_model_to_run(r, return_in_api=True)) + runs.append( + run_model_to_run( + r, + return_in_api=True, + include_jobs=include_jobs, + job_submissions_limit=job_submissions_limit, + ) + ) except pydantic.ValidationError: pass if len(run_models) > len(runs): @@ -652,51 +662,33 @@ async def delete_runs( def run_model_to_run( run_model: RunModel, - include_job_submissions: bool = True, + include_jobs: bool = True, + job_submissions_limit: Optional[int] = None, return_in_api: bool = False, include_sensitive: bool = False, ) -> Run: jobs: List[Job] = [] - run_jobs = 
sorted(run_model.jobs, key=lambda j: (j.replica_num, j.job_num, j.submission_num)) - for replica_num, replica_submissions in itertools.groupby( - run_jobs, key=lambda j: j.replica_num - ): - for job_num, job_submissions in itertools.groupby( - replica_submissions, key=lambda j: j.job_num - ): - submissions = [] - job_model = None - for job_model in job_submissions: - if include_job_submissions: - job_submission = job_model_to_job_submission(job_model) - if return_in_api: - # Set default non-None values for 0.18 backward-compatibility - # Remove in 0.19 - if job_submission.job_provisioning_data is not None: - if job_submission.job_provisioning_data.hostname is None: - job_submission.job_provisioning_data.hostname = "" - if job_submission.job_provisioning_data.ssh_port is None: - job_submission.job_provisioning_data.ssh_port = 22 - submissions.append(job_submission) - if job_model is not None: - # Use the spec from the latest submission. Submissions can have different specs - job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) - if not include_sensitive: - _remove_job_spec_sensitive_info(job_spec) - jobs.append(Job(job_spec=job_spec, job_submissions=submissions)) + if include_jobs: + jobs = _get_run_jobs_with_submissions( + run_model=run_model, + job_submissions_limit=job_submissions_limit, + return_in_api=return_in_api, + include_sensitive=include_sensitive, + ) run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) latest_job_submission = None - if include_job_submissions: + if len(jobs) > 0 and len(jobs[0].job_submissions) > 0: # TODO(egor-s): does it make sense with replicas and multi-node? 
- if jobs: - latest_job_submission = jobs[0].job_submissions[-1] + latest_job_submission = jobs[0].job_submissions[-1] service_spec = None if run_model.service_spec is not None: service_spec = ServiceSpec.__response__.parse_raw(run_model.service_spec) + status_message = _get_run_status_message(run_model) + error = _get_run_error(run_model) run = Run( id=run_model.id, project_name=run_model.project.name, @@ -704,18 +696,107 @@ def run_model_to_run( submitted_at=run_model.submitted_at.replace(tzinfo=timezone.utc), last_processed_at=run_model.last_processed_at.replace(tzinfo=timezone.utc), status=run_model.status, + status_message=status_message, termination_reason=run_model.termination_reason, run_spec=run_spec, jobs=jobs, latest_job_submission=latest_job_submission, service=service_spec, deployment_num=run_model.deployment_num, + error=error, deleted=run_model.deleted, ) run.cost = _get_run_cost(run) return run +def _get_run_jobs_with_submissions( + run_model: RunModel, + job_submissions_limit: Optional[int], + return_in_api: bool = False, + include_sensitive: bool = False, +) -> List[Job]: + jobs: List[Job] = [] + run_jobs = sorted(run_model.jobs, key=lambda j: (j.replica_num, j.job_num, j.submission_num)) + for replica_num, replica_submissions in itertools.groupby( + run_jobs, key=lambda j: j.replica_num + ): + for job_num, job_models in itertools.groupby(replica_submissions, key=lambda j: j.job_num): + submissions = [] + job_model = None + if job_submissions_limit is not None: + if job_submissions_limit == 0: + # Take latest job submission to return its job_spec + job_models = list(job_models)[-1:] + else: + job_models = list(job_models)[-job_submissions_limit:] + for job_model in job_models: + if job_submissions_limit != 0: + job_submission = job_model_to_job_submission(job_model) + if return_in_api: + # Set default non-None values for 0.18 backward-compatibility + # Remove in 0.19 + if job_submission.job_provisioning_data is not None: + if 
job_submission.job_provisioning_data.hostname is None: + job_submission.job_provisioning_data.hostname = "" + if job_submission.job_provisioning_data.ssh_port is None: + job_submission.job_provisioning_data.ssh_port = 22 + submissions.append(job_submission) + if job_model is not None: + # Use the spec from the latest submission. Submissions can have different specs + job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + if not include_sensitive: + _remove_job_spec_sensitive_info(job_spec) + jobs.append(Job(job_spec=job_spec, job_submissions=submissions)) + return jobs + + +def _get_run_status_message(run_model: RunModel) -> str: + if len(run_model.jobs) == 0: + return run_model.status.value + + sorted_job_models = sorted( + run_model.jobs, key=lambda j: (j.replica_num, j.job_num, j.submission_num) + ) + job_models_grouped_by_job = list( + list(jm) + for _, jm in itertools.groupby(sorted_job_models, key=lambda j: (j.replica_num, j.job_num)) + ) + + if all(job_models[-1].status == JobStatus.PULLING for job_models in job_models_grouped_by_job): + # Show `pulling`` if last job submission of all jobs is pulling + return "pulling" + + if run_model.status in [RunStatus.SUBMITTED, RunStatus.PENDING]: + # Show `retrying` if any job caused the run to retry + for job_models in job_models_grouped_by_job: + last_job_spec = JobSpec.__response__.parse_raw(job_models[-1].job_spec_data) + retry_on_events = last_job_spec.retry.on_events if last_job_spec.retry else [] + last_job_termination_reason = _get_last_job_termination_reason(job_models) + if ( + last_job_termination_reason + == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + and RetryEvent.NO_CAPACITY in retry_on_events + ): + # TODO: Show `retrying` for other retry events + return "retrying" + + return run_model.status.value + + +def _get_last_job_termination_reason(job_models: List[JobModel]) -> Optional[JobTerminationReason]: + for job_model in reversed(job_models): + if job_model.termination_reason 
is not None: + return job_model.termination_reason + return None + + +def _get_run_error(run_model: RunModel) -> Optional[str]: + if run_model.termination_reason is None: + return None + return run_model.termination_reason.to_error() + + async def _get_pool_offers( session: AsyncSession, project: ProjectModel, @@ -914,6 +995,8 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec): "replicas", "scaling", # rolling deployment + # NOTE: keep this list in sync with the "Rolling deployment" section in services.md + "port", "resources", "volumes", "docker", diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 062390b9a..ba3168c59 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -22,7 +22,7 @@ from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.runs import Run, RunSpec, ServiceModelSpec, ServiceSpec +from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.server import settings from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel from dstack._internal.server.services.gateways import ( @@ -179,6 +179,7 @@ async def register_replica( async with conn.client() as client: await client.register_replica( run=run, + job_spec=JobSpec.__response__.parse_raw(job_model.job_spec_data), job_submission=job_submission, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index 7aaf4b979..8504f2af7 100644 --- 
a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -44,7 +44,9 @@ async def list_users_for_user( session: AsyncSession, user: UserModel, ) -> List[User]: - return await list_all_users(session=session) + if user.global_role == GlobalRole.ADMIN: + return await list_all_users(session=session) + return [user_model_to_user(user)] async def list_all_users( diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 9a7ed53d3..be43e02ce 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -401,6 +401,19 @@ def _validate_volume_configuration(configuration: VolumeConfiguration): if configuration.name is not None: validate_dstack_resource_name(configuration.name) + if configuration.volume_id is not None and configuration.auto_cleanup_duration is not None: + if ( + isinstance(configuration.auto_cleanup_duration, int) + and configuration.auto_cleanup_duration > 0 + ) or ( + isinstance(configuration.auto_cleanup_duration, str) + and configuration.auto_cleanup_duration not in ("off", "-1") + ): + raise ServerClientError( + "External volumes (with volume_id) do not support auto_cleanup_duration. " + "Auto-cleanup only works for volumes created and managed by dstack." 
+ ) + async def _delete_volume(session: AsyncSession, project: ProjectModel, volume_model: VolumeModel): volume = volume_model_to_volume(volume_model) diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index a60fc5e65..f025d7164 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -42,6 +42,11 @@ os.getenv("DSTACK_SERVER_BACKGROUND_PROCESSING_FACTOR", 1) ) +SERVER_BACKGROUND_PROCESSING_DISABLED = ( + os.getenv("DSTACK_SERVER_BACKGROUND_PROCESSING_DISABLED") is not None +) +SERVER_BACKGROUND_PROCESSING_ENABLED = not SERVER_BACKGROUND_PROCESSING_DISABLED + SERVER_EXECUTOR_MAX_WORKERS = int(os.getenv("DSTACK_SERVER_EXECUTOR_MAX_WORKERS", 128)) MAX_OFFERS_TRIED = int(os.getenv("DSTACK_SERVER_MAX_OFFERS_TRIED", 25)) @@ -113,5 +118,5 @@ UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_UPDATE_DEFAULT_PROJECT") is not None DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None -SKIP_GATEWAY_UPDATE = os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None) is not None -ENABLE_PROMETHEUS_METRICS = os.getenv("DSTACK_ENABLE_PROMETHEUS_METRICS", None) is not None +SKIP_GATEWAY_UPDATE = os.getenv("DSTACK_SKIP_GATEWAY_UPDATE") is not None +ENABLE_PROMETHEUS_METRICS = os.getenv("DSTACK_ENABLE_PROMETHEUS_METRICS") is not None diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 047adb5c1..02d3ac242 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -31,6 +31,8 @@ FleetSpec, FleetStatus, InstanceGroupPlacement, + SSHHostParams, + SSHParams, ) from dstack._internal.core.models.gateways import GatewayComputeConfiguration, GatewayStatus from dstack._internal.core.models.instances import ( @@ -378,6 +380,7 @@ def get_job_provisioning_data( hostname: str = "127.0.0.4", internal_ip: Optional[str] = "127.0.0.4", price: float = 10.5, + instance_type: 
Optional[InstanceType] = None, ) -> JobProvisioningData: gpus = [ Gpu( @@ -386,14 +389,16 @@ def get_job_provisioning_data( vendor=gpuhunt.AcceleratorVendor.NVIDIA, ) ] * gpu_count - return JobProvisioningData( - backend=backend, - instance_type=InstanceType( + if instance_type is None: + instance_type = InstanceType( name="instance", resources=Resources( cpus=cpu_count, memory_mib=int(memory_gib * 1024), spot=spot, gpus=gpus ), - ), + ) + return JobProvisioningData( + backend=backend, + instance_type=instance_type, instance_id="instance_id", hostname=hostname, internal_ip=internal_ip, @@ -549,6 +554,31 @@ def get_fleet_configuration( ) +def get_ssh_fleet_configuration( + name: str = "test-fleet", + user: str = "ubuntu", + ssh_key: Optional[SSHKey] = None, + hosts: Optional[list[Union[SSHHostParams, str]]] = None, + network: Optional[str] = None, + placement: Optional[InstanceGroupPlacement] = None, +) -> FleetConfiguration: + if ssh_key is None: + ssh_key = SSHKey(public="", private=get_private_key_string()) + if hosts is None: + hosts = ["10.0.0.100"] + ssh_config = SSHParams( + user=user, + ssh_key=ssh_key, + hosts=hosts, + network=network, + ) + return FleetConfiguration( + name=name, + ssh_config=ssh_config, + placement=placement, + ) + + async def create_instance( session: AsyncSession, project: ProjectModel, @@ -590,7 +620,9 @@ async def create_instance( internal_ip=None, ) if offer == "auto": - offer = get_instance_offer_with_availability(backend=backend, region=region, spot=spot) + offer = get_instance_offer_with_availability( + backend=backend, region=region, spot=spot, price=price + ) if profile is None: profile = Profile(name="test_name") @@ -742,6 +774,7 @@ async def create_volume( status: VolumeStatus = VolumeStatus.SUBMITTED, created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), last_processed_at: Optional[datetime] = None, + last_job_processed_at: Optional[datetime] = None, configuration: Optional[VolumeConfiguration] = None, 
volume_provisioning_data: Optional[VolumeProvisioningData] = None, deleted_at: Optional[datetime] = None, @@ -759,6 +792,7 @@ async def create_volume( status=status, created_at=created_at, last_processed_at=last_processed_at, + last_job_processed_at=last_job_processed_at, configuration=configuration.json(), volume_provisioning_data=volume_provisioning_data.json() if volume_provisioning_data @@ -820,6 +854,7 @@ def get_volume_configuration( region: str = "eu-west-1", size: Optional[Memory] = Memory(100), volume_id: Optional[str] = None, + auto_cleanup_duration: Optional[Union[str, int]] = None, ) -> VolumeConfiguration: return VolumeConfiguration( name=name, @@ -827,6 +862,7 @@ def get_volume_configuration( region=region, size=size, volume_id=volume_id, + auto_cleanup_duration=auto_cleanup_duration, ) diff --git a/src/dstack/_internal/server/utils/routers.py b/src/dstack/_internal/server/utils/routers.py index f8ca21004..131ec5cc3 100644 --- a/src/dstack/_internal/server/utils/routers.py +++ b/src/dstack/_internal/server/utils/routers.py @@ -1,11 +1,34 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional -from fastapi import HTTPException, Request, status -from fastapi.responses import JSONResponse +import orjson +from fastapi import HTTPException, Request, Response, status from packaging import version from dstack._internal.core.errors import ServerClientError, ServerClientErrorCode from dstack._internal.core.models.common import CoreModel +from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default + + +class CustomORJSONResponse(Response): + """ + Custom JSONResponse that uses orjson for serialization. + + It's recommended to return this class from routers directly instead of + returning pydantic models to avoid the FastAPI's jsonable_encoder overhead. + See https://fastapi.tiangolo.com/advanced/custom-response/#use-orjsonresponse. 
+ + Beware that FastAPI skips model validation when responses are returned directly. + If serialization needs to be modified, override `dict()` instead of adding validators. + """ + + media_type = "application/json" + + def render(self, content: Any) -> bytes: + return orjson.dumps( + content, + option=get_orjson_default_options(), + default=orjson_default, + ) class BadRequestDetailsModel(CoreModel): @@ -30,7 +53,7 @@ def get_base_api_additional_responses() -> Dict: """ Returns additional responses for the OpenAPI docs relevant to all API endpoints. The endpoints may override responses to make them as specific as possible. - E.g. an enpoint may specify which error codes it may return in `code`. + E.g. an endpoint may specify which error codes it may return in `code`. """ return { 400: get_bad_request_additional_response(), @@ -102,7 +125,7 @@ def get_request_size(request: Request) -> int: def check_client_server_compatibility( client_version: Optional[str], server_version: Optional[str], -) -> Optional[JSONResponse]: +) -> Optional[CustomORJSONResponse]: """ Returns `JSONResponse` with error if client/server versions are incompatible. Returns `None` otherwise. @@ -116,7 +139,7 @@ def check_client_server_compatibility( try: parsed_client_version = version.parse(client_version) except version.InvalidVersion: - return JSONResponse( + return CustomORJSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content={ "detail": get_server_client_error_details( @@ -138,11 +161,11 @@ def error_incompatible_versions( client_version: Optional[str], server_version: str, ask_cli_update: bool, -) -> JSONResponse: +) -> CustomORJSONResponse: msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})." if ask_cli_update: msg += f" Update the dstack CLI: `pip install dstack=={server_version}`." 
- return JSONResponse( + return CustomORJSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content={"detail": get_server_client_error_details(ServerClientError(msg=msg))}, ) diff --git a/src/dstack/_internal/utils/json_utils.py b/src/dstack/_internal/utils/json_utils.py new file mode 100644 index 000000000..9017e94c3 --- /dev/null +++ b/src/dstack/_internal/utils/json_utils.py @@ -0,0 +1,54 @@ +from typing import Any + +import orjson +from pydantic import BaseModel + +FREEZEGUN = True +try: + from freezegun.api import FakeDatetime +except ImportError: + FREEZEGUN = False + + +ASYNCPG = True +try: + import asyncpg.pgproto.pgproto +except ImportError: + ASYNCPG = False + + +def pydantic_orjson_dumps(v: Any, *, default: Any) -> str: + return orjson.dumps( + v, + option=get_orjson_default_options(), + default=orjson_default, + ).decode() + + +def pydantic_orjson_dumps_with_indent(v: Any, *, default: Any) -> str: + return orjson.dumps( + v, + option=get_orjson_default_options() | orjson.OPT_INDENT_2, + default=orjson_default, + ).decode() + + +def orjson_default(obj): + if isinstance(obj, float): + # orjson does not convert float subclasses be default + return float(obj) + if isinstance(obj, BaseModel): + # Allows calling orjson.dumps() on pydantic models + # (e.g. 
to return from the API) + return obj.dict() + if ASYNCPG: + if isinstance(obj, asyncpg.pgproto.pgproto.UUID): + return str(obj) + if FREEZEGUN: + if isinstance(obj, FakeDatetime): + return obj.isoformat() + raise TypeError + + +def get_orjson_default_options() -> int: + return orjson.OPT_NON_STR_KEYS diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 1a4e0e1e2..e1992068d 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -18,7 +18,11 @@ from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.configurations import AnyRunConfiguration, PortMapping +from dstack._internal.core.models.configurations import ( + AnyRunConfiguration, + PortMapping, + ServiceConfiguration, +) from dstack._internal.core.models.files import FileArchiveMapping, FilePathMapping from dstack._internal.core.models.profiles import ( CreationPolicy, @@ -38,6 +42,7 @@ RunPlan, RunSpec, RunStatus, + get_service_port, ) from dstack._internal.core.models.runs import Run as RunModel from dstack._internal.core.services.logs import URLReplacer @@ -163,7 +168,7 @@ def ws_thread(): service_port = 443 if secure else 80 ports = { **ports, - self._run.run_spec.configuration.port.container_port: service_port, + get_or_error(get_or_error(self._ssh_attach).service_port): service_port, } path_prefix = url.path replace_urls = URLReplacer( @@ -338,6 +343,10 @@ def attach( else: container_user = "root" + service_port = None + if isinstance(self._run.run_spec.configuration, ServiceConfiguration): + service_port = get_service_port(job.job_spec, self._run.run_spec.configuration) + self._ssh_attach = SSHAttach( hostname=provisioning_data.hostname, ssh_port=provisioning_data.ssh_port, @@ -349,6 +358,7 @@ def attach( run_name=name, 
dockerized=provisioning_data.dockerized, ssh_proxy=provisioning_data.ssh_proxy, + service_port=service_port, local_backend=provisioning_data.backend == BackendType.LOCAL, bind_address=bind_address, ) @@ -748,6 +758,7 @@ def list(self, all: bool = False, limit: Optional[int] = None) -> List[Run]: repo_id=None, only_active=only_active, limit=limit or 100, + # TODO: Pass job_submissions_limit=1 in 0.20 ) if only_active and len(runs) == 0: runs = self._api_client.runs.list( diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 2c85792eb..745ce9c78 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -4,7 +4,11 @@ from pydantic import parse_obj_as -from dstack._internal.core.compatibility.runs import get_apply_plan_excludes, get_get_plan_excludes +from dstack._internal.core.compatibility.runs import ( + get_apply_plan_excludes, + get_get_plan_excludes, + get_list_runs_excludes, +) from dstack._internal.core.models.runs import ( ApplyRunPlanInput, Run, @@ -33,18 +37,24 @@ def list( prev_run_id: Optional[UUID] = None, limit: int = 100, ascending: bool = False, + include_jobs: bool = True, + job_submissions_limit: Optional[int] = None, ) -> List[Run]: body = ListRunsRequest( project_name=project_name, repo_id=repo_id, username=username, only_active=only_active, + include_jobs=include_jobs, + job_submissions_limit=job_submissions_limit, prev_submitted_at=prev_submitted_at, prev_run_id=prev_run_id, limit=limit, ascending=ascending, ) - resp = self._request("/api/runs/list", body=body.json()) + resp = self._request( + "/api/runs/list", body=body.json(exclude=get_list_runs_excludes(body)) + ) return parse_obj_as(List[Run.__response__], resp.json()) def get(self, project_name: str, run_name: str) -> Run: diff --git a/src/tests/_internal/core/models/test_runs.py b/src/tests/_internal/core/models/test_runs.py index 851cba9e3..23b27c018 100644 --- a/src/tests/_internal/core/models/test_runs.py +++ 
b/src/tests/_internal/core/models/test_runs.py @@ -1,9 +1,7 @@ from dstack._internal.core.models.profiles import RetryEvent from dstack._internal.core.models.runs import ( JobStatus, - JobSubmission, JobTerminationReason, - Run, RunStatus, RunTerminationReason, ) @@ -33,8 +31,9 @@ def test_job_termination_reason_to_retry_event_works_with_all_enum_variants(): assert retry_event is None or isinstance(retry_event, RetryEvent) -# Will fail if JobTerminationReason value is added without updaing JobSubmission._get_error +# Will fail if JobTerminationReason value is added without updating JobSubmission._get_error def test_get_error_returns_expected_messages(): + # already handled and shown in status_message no_error_reasons = [ JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY, JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY, @@ -47,7 +46,7 @@ def test_get_error_returns_expected_messages(): ] for reason in JobTerminationReason: - if JobSubmission._get_error(reason) is None: + if reason.to_error() is None: # Fail no-error reason is not in the list assert reason in no_error_reasons @@ -62,6 +61,6 @@ def test_run_get_error_returns_none_for_specific_reasons(): ] for reason in RunTerminationReason: - if Run._get_error(reason) is None: + if reason.to_error() is None: # Fail no-error reason is not in the list assert reason in no_error_reasons diff --git a/src/tests/_internal/server/background/tasks/test_process_gateways.py b/src/tests/_internal/server/background/tasks/test_process_gateways.py index 159547af4..3460f18cb 100644 --- a/src/tests/_internal/server/background/tasks/test_process_gateways.py +++ b/src/tests/_internal/server/background/tasks/test_process_gateways.py @@ -5,56 +5,48 @@ from dstack._internal.core.errors import BackendError from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus -from dstack._internal.server.background.tasks.process_gateways import process_submitted_gateways +from 
dstack._internal.server.background.tasks.process_gateways import process_gateways from dstack._internal.server.testing.common import ( AsyncContextManager, ComputeMockSpec, create_backend, create_gateway, + create_gateway_compute, create_project, ) +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestProcessSubmittedGateways: - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_provisions_gateway(self, test_db, session: AsyncSession): + async def test_submitted_to_provisioning(self, test_db, session: AsyncSession): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, + status=GatewayStatus.SUBMITTED, ) - with ( - patch( - "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" - ) as m, - patch( - "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" - ) as pool_add, - ): + with patch( + "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" + ) as m: aws = Mock() m.return_value = (backend, aws) - pool_add.return_value = MagicMock() - pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) aws.compute.return_value = Mock(spec=ComputeMockSpec) aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData( instance_id="i-1234567890", ip_address="2.2.2.2", region="us", ) - await process_submitted_gateways() + await process_gateways() m.assert_called_once() aws.compute.return_value.create_gateway.assert_called_once() - pool_add.assert_called_once() await session.refresh(gateway) - assert gateway.status == GatewayStatus.RUNNING + assert gateway.status == GatewayStatus.PROVISIONING assert gateway.gateway_compute is not None assert 
gateway.gateway_compute.ip_address == "2.2.2.2" - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_marks_gateway_as_failed_if_gateway_creation_errors( self, test_db, session: AsyncSession ): @@ -64,6 +56,7 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors( session=session, project_id=project.id, backend_id=backend.id, + status=GatewayStatus.SUBMITTED, ) with patch( "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" @@ -72,47 +65,57 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors( m.return_value = (backend, aws) aws.compute.return_value = Mock(spec=ComputeMockSpec) aws.compute.return_value.create_gateway.side_effect = BackendError("Some error") - await process_submitted_gateways() + await process_gateways() m.assert_called_once() aws.compute.return_value.create_gateway.assert_called_once() await session.refresh(gateway) assert gateway.status == GatewayStatus.FAILED assert gateway.status_message == "Some error" - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestProcessProvisioningGateways: + async def test_provisioning_to_running(self, test_db, session: AsyncSession): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute(session) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.PROVISIONING, + ) + with patch( + "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" + ) as pool_add: + pool_add.return_value = MagicMock() + pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) + await 
process_gateways() + pool_add.assert_called_once() + await session.refresh(gateway) + assert gateway.status == GatewayStatus.RUNNING + async def test_marks_gateway_as_failed_if_fails_to_connect( self, test_db, session: AsyncSession ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute(session) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.PROVISIONING, ) - with ( - patch( - "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" - ) as m, - patch( - "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" - ) as connect_to_gateway_with_retry_mock, - ): - aws = Mock() - m.return_value = (backend, aws) + with patch( + "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" + ) as connect_to_gateway_with_retry_mock: connect_to_gateway_with_retry_mock.return_value = None - aws.compute.return_value = Mock(spec=ComputeMockSpec) - aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData( - instance_id="i-1234567890", - ip_address="2.2.2.2", - region="us", - ) - await process_submitted_gateways() - m.assert_called_once() - aws.compute.return_value.create_gateway.assert_called_once() + await process_gateways() connect_to_gateway_with_retry_mock.assert_called_once() await session.refresh(gateway) assert gateway.status == GatewayStatus.FAILED - assert gateway.gateway_compute is not None - assert gateway.gateway_compute is not None + assert gateway.status_message == "Failed to connect to gateway" diff --git a/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py b/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py new file mode 100644 index 000000000..8f6621186 --- /dev/null +++ 
b/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py @@ -0,0 +1,190 @@ +import datetime +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.volumes import VolumeStatus +from dstack._internal.server.background.tasks.process_idle_volumes import ( + _get_idle_time, + _should_delete_volume, + process_idle_volumes, +) +from dstack._internal.server.models import VolumeAttachmentModel +from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_instance, + create_project, + create_user, + create_volume, + get_volume_configuration, + get_volume_provisioning_data, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestProcessIdleVolumes: + async def test_deletes_idle_volumes(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + + config1 = get_volume_configuration( + name="test-volume", + auto_cleanup_duration="1h", + ) + config2 = get_volume_configuration( + name="test-volume", + auto_cleanup_duration="3h", + ) + volume1 = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=config1, + volume_provisioning_data=get_volume_provisioning_data(), + last_job_processed_at=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(hours=2), + ) + volume2 = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=config2, + volume_provisioning_data=get_volume_provisioning_data(), + last_job_processed_at=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(hours=2), + ) + await session.commit() + + with patch( + 
"dstack._internal.server.services.backends.get_project_backend_by_type_or_error" + ) as m: + aws_mock = Mock() + m.return_value = aws_mock + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + await process_idle_volumes() + + await session.refresh(volume1) + await session.refresh(volume2) + assert volume1.deleted + assert volume1.deleted_at is not None + assert not volume2.deleted + assert volume2.deleted_at is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestShouldDeleteVolume: + async def test_no_idle_duration(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=get_volume_configuration(name="test-volume"), + volume_provisioning_data=get_volume_provisioning_data(), + ) + + assert not _should_delete_volume(volume) + + async def test_idle_duration_disabled(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + + config = get_volume_configuration(name="test-volume") + config.auto_cleanup_duration = -1 + + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=config, + volume_provisioning_data=get_volume_provisioning_data(), + ) + + assert not _should_delete_volume(volume) + + async def test_volume_attached(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + + config = get_volume_configuration(name="test-volume") + config.auto_cleanup_duration = "1h" + + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + 
configuration=config, + volume_provisioning_data=get_volume_provisioning_data(), + ) + + instance = await create_instance(session=session, project=project) + volume.attachments.append( + VolumeAttachmentModel(volume_id=volume.id, instance_id=instance.id) + ) + await session.commit() + + assert not _should_delete_volume(volume) + + async def test_idle_duration_threshold(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + + config = get_volume_configuration(name="test-volume") + config.auto_cleanup_duration = "1h" + + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=config, + volume_provisioning_data=get_volume_provisioning_data(), + ) + + # Not exceeded - 30 minutes ago + volume.last_job_processed_at = ( + datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=30) + ).replace(tzinfo=None) + assert not _should_delete_volume(volume) + + # Exceeded - 2 hours ago + volume.last_job_processed_at = ( + datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=2) + ).replace(tzinfo=None) + assert _should_delete_volume(volume) + + async def test_never_used_volume(self, test_db, session: AsyncSession): + project = await create_project(session=session) + user = await create_user(session=session) + + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=get_volume_configuration(name="test-volume"), + volume_provisioning_data=get_volume_provisioning_data(), + created_at=datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=2), + ) + + volume.last_job_processed_at = None + idle_time = _get_idle_time(volume) + assert idle_time.total_seconds() >= 7000 diff --git a/src/tests/_internal/server/routers/test_fleets.py 
b/src/tests/_internal/server/routers/test_fleets.py index 87a970c73..a0ae8ce72 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -10,7 +10,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.fleets import FleetConfiguration, FleetStatus, SSHParams +from dstack._internal.core.models.fleets import ( + FleetConfiguration, + FleetStatus, + InstanceGroupPlacement, + SSHParams, +) from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, @@ -21,6 +26,7 @@ ) from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.models import FleetModel, InstanceModel +from dstack._internal.server.services.fleets import fleet_model_to_fleet from dstack._internal.server.services.permissions import DefaultPermissions from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( @@ -35,7 +41,11 @@ get_auth_headers, get_fleet_configuration, get_fleet_spec, + get_instance_offer_with_availability, + get_job_provisioning_data, get_private_key_string, + get_remote_connection_info, + get_ssh_fleet_configuration, ) pytestmark = pytest.mark.usefixtures("image_config_mock") @@ -415,17 +425,14 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A await add_project_member( session=session, project=project, user=user, project_role=ProjectRole.USER ) - spec = get_fleet_spec( - conf=FleetConfiguration( - name="test-ssh-fleet", - ssh_config=SSHParams( - user="ubuntu", - ssh_key=SSHKey(public="", private=get_private_key_string()), - hosts=["1.1.1.1"], - network=None, - ), - ) + conf = get_ssh_fleet_configuration( + name="test-ssh-fleet", + user="ubuntu", + ssh_key=SSHKey(public="", private=get_private_key_string()), + hosts=["1.1.1.1"], + 
network=None, ) + spec = get_fleet_spec(conf=conf) with patch("uuid.uuid4") as m: m.return_value = UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e") response = await client.post( @@ -541,6 +548,212 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A instance = res.unique().scalar_one() assert instance.remote_connection_info is not None + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), real_asyncio=True) + async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: AsyncClient): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + current_conf = get_ssh_fleet_configuration( + name="test-ssh-fleet", + user="ubuntu", + ssh_key=SSHKey(public="", private=get_private_key_string()), + hosts=["10.0.0.100"], + network=None, + ) + current_spec = get_fleet_spec(conf=current_conf) + spec = current_spec.copy(deep=True) + # 10.0.0.100 removed, 10.0.0.101 added + spec.configuration.ssh_config.hosts = ["10.0.0.101"] + + fleet = await create_fleet(session=session, project=project, spec=current_spec) + instance_type = InstanceType( + name="ssh", + resources=Resources(cpus=2, memory_mib=8, gpus=[], spot=False), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + backend=BackendType.REMOTE, + name="test-ssh-fleet-0", + region="remote", + price=0.0, + status=InstanceStatus.IDLE, + offer=get_instance_offer_with_availability( + backend=BackendType.REMOTE, + region="remote", + price=0.0, + ), + job_provisioning_data=get_job_provisioning_data( + instance_type=instance_type, + hostname="10.0.0.100", + ), + remote_connection_info=get_remote_connection_info(host="10.0.0.100"), + ) + + with patch("uuid.uuid4") as m: + m.return_value 
= UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e") + response = await client.post( + f"/api/project/{project.name}/fleets/apply", + headers=get_auth_headers(user.token), + json={ + "plan": { + "spec": spec.dict(), + "current_resource": _fleet_model_to_json_dict(fleet), + }, + "force": False, + }, + ) + + assert response.status_code == 200, response.json() + assert response.json() == { + "id": str(fleet.id), + "name": spec.configuration.name, + "project_name": project.name, + "spec": { + "configuration_path": spec.configuration_path, + "configuration": { + "env": {}, + "ssh_config": { + "user": "ubuntu", + "port": None, + "identity_file": None, + "ssh_key": None, # should not return ssh_key + "proxy_jump": None, + "hosts": ["10.0.0.101"], + "network": None, + }, + "nodes": None, + "placement": None, + "resources": { + "cpu": {"min": 2, "max": None}, + "memory": {"min": 8.0, "max": None}, + "shm_size": None, + "gpu": None, + "disk": {"size": {"min": 100.0, "max": None}}, + }, + "backends": None, + "regions": None, + "availability_zones": None, + "instance_types": None, + "spot_policy": None, + "retry": None, + "max_price": None, + "idle_duration": None, + "type": "fleet", + "name": spec.configuration.name, + "reservation": None, + "blocks": 1, + "tags": None, + }, + "profile": { + "backends": None, + "regions": None, + "availability_zones": None, + "instance_types": None, + "spot_policy": None, + "retry": None, + "max_duration": None, + "stop_duration": None, + "max_price": None, + "creation_policy": None, + "idle_duration": None, + "utilization_policy": None, + "startup_order": None, + "stop_criteria": None, + "name": "", + "default": False, + "reservation": None, + "fleets": None, + "tags": None, + }, + "autocreated": False, + }, + "created_at": "2023-01-02T03:04:00+00:00", + "status": "active", + "status_message": None, + "instances": [ + { + "id": str(instance.id), + "project_name": project.name, + "backend": "remote", + "instance_type": { + "name": "ssh", + 
"resources": { + "cpu_arch": None, + "cpus": 2, + "memory_mib": 8, + "gpus": [], + "spot": False, + "disk": {"size_mib": 102400}, + "description": "cpu=2 mem=0GB disk=100GB", + }, + }, + "name": "test-ssh-fleet-0", + "fleet_id": str(fleet.id), + "fleet_name": "test-ssh-fleet", + "instance_num": 0, + "job_name": None, + "hostname": "10.0.0.100", + "status": "terminating", + "unreachable": False, + "termination_reason": None, + "created": "2023-01-02T03:04:00+00:00", + "region": "remote", + "availability_zone": None, + "price": 0.0, + "total_blocks": 1, + "busy_blocks": 0, + }, + { + "id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", + "project_name": project.name, + "backend": "remote", + "instance_type": { + "name": "ssh", + "resources": { + "cpu_arch": None, + "cpus": 2, + "memory_mib": 8, + "gpus": [], + "spot": False, + "disk": {"size_mib": 102400}, + "description": "cpu=2 mem=0GB disk=100GB", + }, + }, + "name": "test-ssh-fleet-1", + "fleet_id": str(fleet.id), + "fleet_name": "test-ssh-fleet", + "instance_num": 1, + "job_name": None, + "hostname": "10.0.0.101", + "status": "pending", + "unreachable": False, + "termination_reason": None, + "created": "2023-01-02T03:04:00+00:00", + "region": "remote", + "availability_zone": None, + "price": 0.0, + "total_blocks": 1, + "busy_blocks": 0, + }, + ], + } + res = await session.execute(select(FleetModel)) + assert res.scalar_one() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + res = await session.execute( + select(InstanceModel).where(InstanceModel.id == "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e") + ) + instance = res.unique().scalar_one() + assert instance.status == InstanceStatus.PENDING + assert instance.remote_connection_info is not None + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)) @@ -820,7 +1033,9 @@ async def test_returns_40x_if_not_authenticated( 
@pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_returns_plan(self, test_db, session: AsyncSession, client: AsyncClient): + async def test_returns_create_plan_for_new_fleet( + self, test_db, session: AsyncSession, client: AsyncClient + ): user = await create_user(session=session, global_role=GlobalRole.USER) project = await create_project(session=session, owner=user) await add_project_member( @@ -855,10 +1070,91 @@ async def test_returns_plan(self, test_db, session: AsyncSession, client: AsyncC assert response.json() == { "project_name": project.name, "user": user.name, - "spec": spec.dict(), - "effective_spec": spec.dict(), + "spec": json.loads(spec.json()), + "effective_spec": json.loads(spec.json()), "current_resource": None, "offers": [json.loads(o.json()) for o in offers], "total_offers": len(offers), "max_offer_price": 1.0, + "action": "create", } + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_update_plan_for_existing_fleet( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + conf = get_ssh_fleet_configuration(hosts=["10.0.0.100"]) + spec = get_fleet_spec(conf=conf) + effective_spec = spec.copy(deep=True) + effective_spec.configuration.ssh_config.ssh_key = None + current_spec = spec.copy(deep=True) + # `hosts` can be updated in-place + current_spec.configuration.ssh_config.hosts = ["10.0.0.100", "10.0.0.101"] + fleet = await create_fleet(session=session, project=project, spec=current_spec) + + response = await client.post( + f"/api/project/{project.name}/fleets/get_plan", + headers=get_auth_headers(user.token), + json={"spec": spec.dict()}, + ) + + assert 
response.status_code == 200 + assert response.json() == { + "project_name": project.name, + "user": user.name, + "spec": spec.dict(), + "effective_spec": effective_spec.dict(), + "current_resource": _fleet_model_to_json_dict(fleet), + "offers": [], + "total_offers": 0, + "max_offer_price": None, + "action": "update", + } + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_create_plan_for_existing_fleet( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + conf = get_ssh_fleet_configuration(placement=InstanceGroupPlacement.ANY) + spec = get_fleet_spec(conf=conf) + effective_spec = spec.copy(deep=True) + effective_spec.configuration.ssh_config.ssh_key = None + current_spec = spec.copy(deep=True) + # `placement` cannot be updated in-place + current_spec.configuration.placement = InstanceGroupPlacement.CLUSTER + fleet = await create_fleet(session=session, project=project, spec=current_spec) + + response = await client.post( + f"/api/project/{project.name}/fleets/get_plan", + headers=get_auth_headers(user.token), + json={"spec": spec.dict()}, + ) + + assert response.status_code == 200 + assert response.json() == { + "project_name": project.name, + "user": user.name, + "spec": spec.dict(), + "effective_spec": effective_spec.dict(), + "current_resource": _fleet_model_to_json_dict(fleet), + "offers": [], + "total_offers": 0, + "max_offer_price": None, + "action": "create", + } + + +def _fleet_model_to_json_dict(fleet: FleetModel) -> dict: + return json.loads(fleet_model_to_fleet(fleet).json()) diff --git a/src/tests/_internal/server/routers/test_logs.py b/src/tests/_internal/server/routers/test_logs.py index 11f0da8da..0364edee4 100644 --- 
a/src/tests/_internal/server/routers/test_logs.py +++ b/src/tests/_internal/server/routers/test_logs.py @@ -62,17 +62,17 @@ async def test_returns_logs( { "timestamp": "2023-10-06T10:01:53.234234+00:00", "log_source": "stdout", - "message": "Hello", + "message": "SGVsbG8=", }, { "timestamp": "2023-10-06T10:01:53.234235+00:00", "log_source": "stdout", - "message": "World", + "message": "V29ybGQ=", }, { "timestamp": "2023-10-06T10:01:53.234236+00:00", "log_source": "stdout", - "message": "!", + "message": "IQ==", }, ], "next_token": None, @@ -93,7 +93,7 @@ async def test_returns_logs( { "timestamp": "2023-10-06T10:01:53.234236+00:00", "log_source": "stdout", - "message": "!", + "message": "IQ==", }, ], "next_token": None, diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 96adb6499..4ec0c4225 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -246,6 +246,7 @@ def get_dev_env_run_plan_dict( "repo_code_hash": None, "repo_data": {"repo_dir": "/repo", "repo_type": "local"}, "file_archives": [], + "service_port": None, }, "offers": [json.loads(o.json()) for o in offers], "total_offers": total_offers, @@ -441,6 +442,7 @@ def get_dev_env_run_dict( "repo_code_hash": None, "repo_data": {"repo_dir": "/repo", "repo_type": "local"}, "file_archives": [], + "service_port": None, }, "job_submissions": [ { @@ -707,6 +709,108 @@ async def test_lists_runs_pagination( assert len(response2_json) == 1 assert response2_json[0]["id"] == str(run2.id) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_limits_job_submissions( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, 
project_role=ProjectRole.USER + ) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run_submitted_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + submitted_at=run_submitted_at, + ) + run_spec = RunSpec.parse_raw(run.run_spec) + await create_job( + session=session, + run=run, + submitted_at=run_submitted_at, + last_processed_at=run_submitted_at, + ) + job2 = await create_job( + session=session, + run=run, + submitted_at=run_submitted_at, + last_processed_at=run_submitted_at, + ) + job2_spec = JobSpec.parse_raw(job2.job_spec_data) + response = await client.post( + "/api/runs/list", + headers=get_auth_headers(user.token), + json={"job_submissions_limit": 1}, + ) + assert response.status_code == 200, response.json() + assert response.json() == [ + { + "id": str(run.id), + "project_name": project.name, + "user": user.name, + "submitted_at": run_submitted_at.isoformat(), + "last_processed_at": run_submitted_at.isoformat(), + "status": "submitted", + "status_message": "submitted", + "run_spec": run_spec.dict(), + "jobs": [ + { + "job_spec": job2_spec.dict(), + "job_submissions": [ + { + "id": str(job2.id), + "submission_num": 0, + "deployment_num": 0, + "submitted_at": run_submitted_at.isoformat(), + "last_processed_at": run_submitted_at.isoformat(), + "finished_at": None, + "inactivity_secs": None, + "status": "submitted", + "status_message": "submitted", + "termination_reason": None, + "termination_reason_message": None, + "error": None, + "exit_status": None, + "job_provisioning_data": None, + "job_runtime_data": None, + } + ], + } + ], + "latest_job_submission": { + "id": str(job2.id), + "submission_num": 0, + "deployment_num": 0, + "submitted_at": run_submitted_at.isoformat(), + "last_processed_at": run_submitted_at.isoformat(), + "finished_at": None, + "inactivity_secs": None, + "status": "submitted", + "status_message": "submitted", + 
"termination_reason_message": None, + "termination_reason": None, + "error": None, + "exit_status": None, + "job_provisioning_data": None, + "job_runtime_data": None, + }, + "cost": 0, + "service": None, + "deployment_num": 0, + "termination_reason": None, + "error": None, + "deleted": False, + }, + ] + class TestGetRun: @pytest.mark.asyncio @@ -1074,12 +1178,14 @@ async def test_returns_run_plan_instance_volumes( ServiceConfiguration( commands=["one", "two"], port=80, + gateway=None, replicas=1, scaling=None, ), ServiceConfiguration( commands=["one", "two"], - port=8080, # not updatable + port=8080, + gateway="test-gateway", # not updatable replicas="2..4", scaling=ScalingSpec(metric="rps", target=5), ), diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index d19290749..df4a8f3c0 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -22,19 +22,71 @@ async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClie @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_returns_users(self, test_db, session: AsyncSession, client: AsyncClient): - user = await create_user( + async def test_admins_see_all_users(self, test_db, session: AsyncSession, client: AsyncClient): + admin = await create_user( + session=session, + name="admin", + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + global_role=GlobalRole.ADMIN, + ) + other_user = await create_user( + session=session, + name="other_user", + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + global_role=GlobalRole.USER, + ) + response = await client.post("/api/users/list", headers=get_auth_headers(admin.token)) + assert response.status_code in [200] + assert response.json() == [ + { + "id": str(admin.id), + "username": admin.name, + "created_at": "2023-01-02T03:04:00+00:00", + "global_role": 
admin.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + }, + { + "id": str(other_user.id), + "username": other_user.name, + "created_at": "2023-01-02T03:04:00+00:00", + "global_role": other_user.global_role, + "email": None, + "active": True, + "permissions": { + "can_create_projects": True, + }, + }, + ] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_non_admins_see_only_themselves( + self, test_db, session: AsyncSession, client: AsyncClient + ): + await create_user( + session=session, + name="admin", + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + global_role=GlobalRole.ADMIN, + ) + other_user = await create_user( session=session, + name="other_user", created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + global_role=GlobalRole.USER, ) - response = await client.post("/api/users/list", headers=get_auth_headers(user.token)) + response = await client.post("/api/users/list", headers=get_auth_headers(other_user.token)) assert response.status_code in [200] assert response.json() == [ { - "id": str(user.id), - "username": user.name, + "id": str(other_user.id), + "username": other_user.name, "created_at": "2023-01-02T03:04:00+00:00", - "global_role": user.global_role, + "global_role": other_user.global_role, "email": None, "active": True, "permissions": { diff --git a/src/tests/_internal/server/services/test_logs.py b/src/tests/_internal/server/services/test_logs.py index 19769a360..0b9420917 100644 --- a/src/tests/_internal/server/services/test_logs.py +++ b/src/tests/_internal/server/services/test_logs.py @@ -1,4 +1,3 @@ -import base64 import logging from datetime import datetime, timedelta, timezone from pathlib import Path @@ -13,6 +12,7 @@ from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal.core.errors import ServerClientError from dstack._internal.core.models.logs 
import LogEvent, LogEventSource from dstack._internal.server.models import ProjectModel from dstack._internal.server.schemas.logs import PollLogsRequest @@ -51,8 +51,8 @@ async def test_writes_logs(self, test_db, session: AsyncSession, tmp_path: Path) / "runner.log" ) assert runner_log_path.read_text() == ( - '{"timestamp": "2023-10-06T10:01:53.234000+00:00", "log_source": "stdout", "message": "SGVsbG8="}\n' - '{"timestamp": "2023-10-06T10:01:53.235000+00:00", "log_source": "stdout", "message": "V29ybGQ="}\n' + '{"timestamp":"2023-10-06T10:01:53.234000+00:00","log_source":"stdout","message":"Hello"}\n' + '{"timestamp":"2023-10-06T10:01:53.235000+00:00","log_source":"stdout","message":"World"}\n' ) @pytest.mark.asyncio @@ -119,12 +119,8 @@ async def test_poll_logs_with_next_token_pagination( job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 2 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log1".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[1].message == base64.b64encode( - "Log2".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log1" + assert job_submission_logs.logs[1].message == "Log2" assert job_submission_logs.next_token == "2" # Next line to read # Second page: use next_token @@ -132,12 +128,8 @@ async def test_poll_logs_with_next_token_pagination( job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 2 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log3".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[1].message == base64.b64encode( - "Log4".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log3" + assert job_submission_logs.logs[1].message == "Log4" assert job_submission_logs.next_token == "4" # Next line to read # Third page: get remaining log @@ -145,9 +137,7 @@ async def 
test_poll_logs_with_next_token_pagination( job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 1 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log5".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log5" assert job_submission_logs.next_token is None # No more logs @pytest.mark.asyncio @@ -182,12 +172,8 @@ async def test_poll_logs_with_start_from_specific_line( job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 2 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log2".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[1].message == base64.b64encode( - "Log3".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log2" + assert job_submission_logs.logs[1].message == "Log3" assert job_submission_logs.next_token is None @pytest.mark.asyncio @@ -206,49 +192,22 @@ async def test_poll_logs_invalid_next_token_raises_error( limit=10, diagnose=True, ) - with pytest.raises( - LogStorageError, match="Invalid next_token: invalid. Must be a valid integer." - ): + with pytest.raises(ServerClientError): log_storage.poll_logs(project, poll_request) # Test with negative next_token poll_request.next_token = "-1" - with pytest.raises( - LogStorageError, match="Invalid next_token: -1. Must be a non-negative integer." - ): + with pytest.raises(ServerClientError): log_storage.poll_logs(project, poll_request) # Test with float next_token poll_request.next_token = "1.5" - with pytest.raises( - LogStorageError, match="Invalid next_token: 1.5. Must be a valid integer." 
- ): + with pytest.raises(ServerClientError): log_storage.poll_logs(project, poll_request) @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_poll_logs_descending_raises_error( - self, test_db, session: AsyncSession, tmp_path: Path - ): - project = await create_project(session=session) - log_storage = FileLogStorage(tmp_path) - - # Test that descending=True raises LogStorageError - poll_request = PollLogsRequest( - run_name="test_run", - job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), - limit=10, - diagnose=True, - # Note: This bypasses schema validation for testing the implementation - ) - poll_request.descending = True # Set directly to bypass validation - - with pytest.raises(LogStorageError, match="descending: true is not supported"): - log_storage.poll_logs(project, poll_request) - - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_poll_logs_file_not_found_raises_error( + async def test_poll_logs_file_not_found_raises_no_error( self, test_db, session: AsyncSession, tmp_path: Path ): project = await create_project(session=session) @@ -261,11 +220,7 @@ async def test_poll_logs_file_not_found_raises_error( limit=10, diagnose=True, ) - - with pytest.raises( - LogStorageError, match="Failed to read log file .* No such file or directory" - ): - log_storage.poll_logs(project, poll_request) + log_storage.poll_logs(project, poll_request) @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -309,9 +264,7 @@ async def test_poll_logs_with_time_filtering_and_pagination( # Should get Log3 first (timestamp > 235) assert len(job_submission_logs.logs) == 1 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log3".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log3" assert job_submission_logs.next_token == "3" # Get next page @@ -320,9 
+273,7 @@ async def test_poll_logs_with_time_filtering_and_pagination( # Should get Log4 assert len(job_submission_logs.logs) == 1 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log4".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log4" # Should not have next_token since we reached end of file assert job_submission_logs.next_token is None @@ -389,9 +340,9 @@ async def test_next_token_pagination_complete_workflow( page1 = log_storage.poll_logs(project, poll_request) assert len(page1.logs) == 3 - assert page1.logs[0].message == base64.b64encode("Log1".encode()).decode() - assert page1.logs[1].message == base64.b64encode("Log2".encode()).decode() - assert page1.logs[2].message == base64.b64encode("Log3".encode()).decode() + assert page1.logs[0].message == "Log1" + assert page1.logs[1].message == "Log2" + assert page1.logs[2].message == "Log3" assert page1.next_token == "3" # Next line to read # Second page: use next_token @@ -399,9 +350,9 @@ async def test_next_token_pagination_complete_workflow( page2 = log_storage.poll_logs(project, poll_request) assert len(page2.logs) == 3 - assert page2.logs[0].message == base64.b64encode("Log4".encode()).decode() - assert page2.logs[1].message == base64.b64encode("Log5".encode()).decode() - assert page2.logs[2].message == base64.b64encode("Log6".encode()).decode() + assert page2.logs[0].message == "Log4" + assert page2.logs[1].message == "Log5" + assert page2.logs[2].message == "Log6" assert page2.next_token == "6" # Third page: get more logs @@ -409,9 +360,9 @@ async def test_next_token_pagination_complete_workflow( page3 = log_storage.poll_logs(project, poll_request) assert len(page3.logs) == 3 - assert page3.logs[0].message == base64.b64encode("Log7".encode()).decode() - assert page3.logs[1].message == base64.b64encode("Log8".encode()).decode() - assert page3.logs[2].message == base64.b64encode("Log9".encode()).decode() + assert page3.logs[0].message == "Log7" + 
assert page3.logs[1].message == "Log8" + assert page3.logs[2].message == "Log9" assert page3.next_token == "9" # Fourth page: get last log @@ -419,8 +370,8 @@ async def test_next_token_pagination_complete_workflow( page4 = log_storage.poll_logs(project, poll_request) assert len(page4.logs) == 1 - assert page4.logs[0].message == base64.b64encode("Log10".encode()).decode() - assert page4.next_token is None # No more logs + assert page4.logs[0].message == "Log10" + assert page4.next_token is None @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -458,15 +409,15 @@ async def test_next_token_with_time_filtering( page1 = log_storage.poll_logs(project, poll_request) assert len(page1.logs) == 2 - assert page1.logs[0].message == base64.b64encode("Log3".encode()).decode() - assert page1.logs[1].message == base64.b64encode("Log4".encode()).decode() + assert page1.logs[0].message == "Log3" + assert page1.logs[1].message == "Log4" assert page1.next_token == "4" # Get next page poll_request.next_token = page1.next_token page2 = log_storage.poll_logs(project, poll_request) assert len(page2.logs) == 1 - assert page2.logs[0].message == base64.b64encode("Log5".encode()).decode() + assert page2.logs[0].message == "Log5" assert page2.next_token is None @pytest.mark.asyncio @@ -497,16 +448,16 @@ async def test_next_token_edge_cases(self, test_db, session: AsyncSession, tmp_p result = log_storage.poll_logs(project, poll_request) assert len(result.logs) == 1 - assert result.logs[0].message == base64.b64encode("OnlyLog".encode()).decode() - assert result.next_token is None # No more logs available + assert result.logs[0].message == "OnlyLog" + assert result.next_token is None # Request with limit equal to available logs poll_request.limit = 1 result = log_storage.poll_logs(project, poll_request) assert len(result.logs) == 1 - assert result.logs[0].message == base64.b64encode("OnlyLog".encode()).decode() - assert result.next_token is None # 
No more logs available + assert result.logs[0].message == "OnlyLog" + assert result.next_token is None @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -607,15 +558,9 @@ async def test_poll_logs_with_limit(self, test_db, session: AsyncSession, tmp_pa # Should return only the first 3 logs and provide next_token assert len(job_submission_logs.logs) == 3 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log1".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[1].message == base64.b64encode( - "Log2".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[2].message == base64.b64encode( - "Log3".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log1" + assert job_submission_logs.logs[1].message == "Log2" + assert job_submission_logs.logs[2].message == "Log3" # Should have next_token pointing to line 3 (fourth log) assert job_submission_logs.next_token == "3" @@ -624,9 +569,7 @@ async def test_poll_logs_with_limit(self, test_db, session: AsyncSession, tmp_pa poll_request.start_time = logs[3].timestamp job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 1 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log5".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log5" # Should not have next_token since we reached end of file assert job_submission_logs.next_token is None @@ -951,14 +894,14 @@ async def test_write_logs( logGroupName="test-group", logStreamName=expected_runner_stream, logEvents=[ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], ), call( logGroupName="test-group", logStreamName=expected_job_stream, logEvents=[ - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513235, "message": "World"}, ], ), ] @@ -1056,11 +999,11 @@ async def 
test_write_logs_not_in_chronological_order( logGroupName="test-group", logStreamName="test-proj/test-run/1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e/runner", logEvents=[ - {"timestamp": 1696586513235, "message": "MQ=="}, - {"timestamp": 1696586513236, "message": "Mg=="}, - {"timestamp": 1696586513237, "message": "Mw=="}, - {"timestamp": 1696586513237, "message": "NA=="}, - {"timestamp": 1696586513237, "message": "NQ=="}, + {"timestamp": 1696586513235, "message": "1"}, + {"timestamp": 1696586513236, "message": "2"}, + {"timestamp": 1696586513237, "message": "3"}, + {"timestamp": 1696586513237, "message": "4"}, + {"timestamp": 1696586513237, "message": "5"}, ], ) assert "events are not in chronological order" in caplog.text @@ -1098,7 +1041,7 @@ def _delta_ms(**kwargs: int) -> int: assert "skipping 1 past event(s)" in caplog.text assert "skipping 2 future event(s)" in caplog.text actual = [ - base64.b64decode(e["message"]).decode() + e["message"] for c in mock_client.put_log_events.call_args_list for e in c.kwargs["logEvents"] ] @@ -1143,8 +1086,8 @@ async def test_write_logs_batching_by_size( messages: List[str], expected: List[List[str]], ): - # maximum 6 bytes: 12 (in base64) + 26 (overhead) = 34 - monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34) + # maximum 6 bytes: 6 (raw bytes) + 26 (overhead) = 32 + monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 32) monkeypatch.setattr(CloudWatchLogStorage, "BATCH_MAX_SIZE", 60) log_storage.write_logs( project=project, @@ -1158,7 +1101,7 @@ async def test_write_logs_batching_by_size( ) assert mock_client.put_log_events.call_count == len(expected) actual = [ - [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]] + [e["message"] for e in c.kwargs["logEvents"]] for c in mock_client.put_log_events.call_args_list ] assert actual == expected @@ -1173,7 +1116,7 @@ async def test_write_logs_batching_by_size( [["111", "111", "111"], ["222"]], ], [ - ["111", "111", "111"] + ["222", "222", 
"toolong", "", "222222"], + ["111", "111", "111"] + ["222", "222", "toolongtoolong", "", "222222"], [["111", "111", "111"], ["222", "222", "222222"]], ], ], @@ -1190,8 +1133,8 @@ async def test_write_logs_batching_by_count( messages: List[str], expected: List[List[str]], ): - # maximum 6 bytes: 12 (in base64) + 26 (overhead) = 34 - monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34) + # maximum 6 bytes: 6 (raw bytes) + 26 (overhead) = 32 + monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 32) monkeypatch.setattr(CloudWatchLogStorage, "EVENT_MAX_COUNT_IN_BATCH", 3) log_storage.write_logs( project=project, @@ -1205,7 +1148,7 @@ async def test_write_logs_batching_by_count( ) assert mock_client.put_log_events.call_count == len(expected) actual = [ - [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]] + [e["message"] for e in c.kwargs["logEvents"]] for c in mock_client.put_log_events.call_args_list ] assert actual == expected @@ -1248,7 +1191,7 @@ def _delta_ms(**kwargs: int) -> int: expected = [["1", "2", "3"], ["4", "5", "6"], ["7"]] assert mock_client.put_log_events.call_count == len(expected) actual = [ - [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]] + [e["message"] for e in c.kwargs["logEvents"]] for c in mock_client.put_log_events.call_args_list ] assert actual == expected @@ -1262,8 +1205,8 @@ async def test_poll_logs_non_empty_response( poll_logs_request: PollLogsRequest, ): mock_client.get_log_events.return_value["events"] = [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ] poll_logs_request.limit = 2 job_submission_logs = log_storage.poll_logs(project, poll_logs_request) @@ -1272,12 +1215,12 @@ async def test_poll_logs_non_empty_response( LogEvent( timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc), 
log_source=LogEventSource.STDOUT, - message="SGVsbG8=", + message="Hello", ), LogEvent( timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc), log_source=LogEventSource.STDOUT, - message="V29ybGQ=", + message="World", ), ] @@ -1290,8 +1233,8 @@ async def test_poll_logs_descending_non_empty_response_on_first_call( poll_logs_request: PollLogsRequest, ): mock_client.get_log_events.return_value["events"] = [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ] poll_logs_request.descending = True poll_logs_request.limit = 2 @@ -1301,12 +1244,12 @@ async def test_poll_logs_descending_non_empty_response_on_first_call( LogEvent( timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc), log_source=LogEventSource.STDOUT, - message="V29ybGQ=", + message="World", ), LogEvent( timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc), log_source=LogEventSource.STDOUT, - message="SGVsbG8=", + message="Hello", ), ] @@ -1322,8 +1265,8 @@ async def test_next_token_ascending_pagination( # Setup response with nextForwardToken mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ], "nextBackwardToken": "bwd", "nextForwardToken": "fwd123", @@ -1357,8 +1300,8 @@ async def test_next_token_descending_pagination( # Setup response with nextBackwardToken mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ], "nextBackwardToken": "bwd456", "nextForwardToken": 
"fwd", @@ -1370,8 +1313,8 @@ async def test_next_token_descending_pagination( assert len(result.logs) == 2 # Events should be reversed for descending order - assert result.logs[0].message == "V29ybGQ=" - assert result.logs[1].message == "SGVsbG8=" + assert result.logs[0].message == "World" + assert result.logs[1].message == "Hello" assert result.next_token == "bwd456" # Should return nextBackwardToken # Verify API was called with correct parameters @@ -1393,7 +1336,7 @@ async def test_next_token_provided_in_request( """Test that provided next_token is passed to CloudWatch API""" mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], "nextBackwardToken": "bwd", "nextForwardToken": "new_fwd", @@ -1449,7 +1392,7 @@ async def test_next_token_with_time_filtering( """Test next_token behavior with time filtering""" mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], "nextBackwardToken": "bwd_with_time", "nextForwardToken": "fwd_with_time", @@ -1487,7 +1430,7 @@ async def test_next_token_missing_in_cloudwatch_response( """Test behavior when CloudWatch doesn't return next tokens""" mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], # No nextBackwardToken or nextForwardToken in response } @@ -1509,7 +1452,7 @@ async def test_next_token_empty_string_in_cloudwatch_response( """Test behavior when CloudWatch returns empty string tokens""" mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], "nextBackwardToken": "", "nextForwardToken": "", @@ -1534,8 +1477,8 @@ async def test_next_token_pagination_workflow( 
mock_client.get_log_events.side_effect = [ { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ], "nextBackwardToken": "bwd", "nextForwardToken": "token_page2", @@ -1543,7 +1486,7 @@ async def test_next_token_pagination_workflow( # Second call - returns final logs without next_token { "events": [ - {"timestamp": 1696586513236, "message": "IQ=="}, + {"timestamp": 1696586513236, "message": "!"}, ], "nextBackwardToken": "final_bwd", "nextForwardToken": "final_fwd", @@ -1556,8 +1499,8 @@ async def test_next_token_pagination_workflow( page1 = log_storage.poll_logs(project, poll_logs_request) assert len(page1.logs) == 2 - assert page1.logs[0].message == "SGVsbG8=" - assert page1.logs[1].message == "V29ybGQ=" + assert page1.logs[0].message == "Hello" + assert page1.logs[1].message == "World" assert page1.next_token == "token_page2" # Second page using next_token @@ -1565,7 +1508,7 @@ async def test_next_token_pagination_workflow( page2 = log_storage.poll_logs(project, poll_logs_request) assert len(page2.logs) == 1 - assert page2.logs[0].message == "IQ==" + assert page2.logs[0].message == "!" 
assert page2.next_token == "final_fwd" # Verify both API calls diff --git a/src/tests/_internal/server/services/test_volumes.py b/src/tests/_internal/server/services/test_volumes.py index c2ba555a1..4de9c3f05 100644 --- a/src/tests/_internal/server/services/test_volumes.py +++ b/src/tests/_internal/server/services/test_volumes.py @@ -3,11 +3,78 @@ import pytest from freezegun import freeze_time -from dstack._internal.core.models.volumes import VolumeStatus -from dstack._internal.server.services.volumes import _get_volume_cost +from dstack._internal.core.errors import ServerClientError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.volumes import VolumeConfiguration, VolumeStatus +from dstack._internal.server.services.volumes import ( + _get_volume_cost, + _validate_volume_configuration, +) from dstack._internal.server.testing.common import get_volume, get_volume_provisioning_data +class TestValidateVolumeConfiguration: + def test_external_volume_with_auto_cleanup_duration_raises_error(self): + """External volumes (with volume_id) should not allow auto_cleanup_duration""" + config = VolumeConfiguration( + backend=BackendType.AWS, + region="us-east-1", + volume_id="vol-123456", + auto_cleanup_duration="1h", + ) + with pytest.raises( + ServerClientError, match="External volumes.*do not support auto_cleanup_duration" + ): + _validate_volume_configuration(config) + + def test_external_volume_with_auto_cleanup_duration_int_raises_error(self): + """External volumes with integer auto_cleanup_duration should also raise error""" + config = VolumeConfiguration( + backend=BackendType.AWS, + region="us-east-1", + volume_id="vol-123456", + auto_cleanup_duration=3600, + ) + with pytest.raises( + ServerClientError, match="External volumes.*do not support auto_cleanup_duration" + ): + _validate_volume_configuration(config) + + def test_external_volume_with_auto_cleanup_disabled_succeeds(self): + """External volumes with 
auto_cleanup_duration='off' or -1 should be allowed""" + config1 = VolumeConfiguration( + backend=BackendType.AWS, + region="us-east-1", + volume_id="vol-123456", + auto_cleanup_duration="off", + ) + config2 = VolumeConfiguration( + backend=BackendType.AWS, + region="us-east-1", + volume_id="vol-123456", + auto_cleanup_duration=-1, + ) + # Should not raise any errors + _validate_volume_configuration(config1) + _validate_volume_configuration(config2) + + def test_external_volume_without_auto_cleanup_succeeds(self): + """External volumes without auto_cleanup_duration should be allowed""" + config = VolumeConfiguration( + backend=BackendType.AWS, region="us-east-1", volume_id="vol-123456" + ) + # Should not raise any errors + _validate_volume_configuration(config) + + def test_new_volume_with_auto_cleanup_duration_succeeds(self): + """New volumes (without volume_id) with auto_cleanup_duration should be allowed""" + config = VolumeConfiguration( + backend=BackendType.AWS, region="us-east-1", size=100, auto_cleanup_duration="1h" + ) + # Should not raise any errors + _validate_volume_configuration(config) + + class TestGetVolumeCost: def test_returns_0_when_no_provisioning_data(self): volume = get_volume(provisioning_data=None)