Feature: Make batching automatic in put_state_dict operations #122
wukaixingxp wants to merge 10 commits into meta-pytorch:main
Conversation
Implements `_put_batch()` as an internal method that parallelizes storage operations and batches controller notifications into a single RPC call. This provides the foundation for automatic batching in `put_state_dict`.

Key optimizations:
- Parallel `asyncio.gather` for concurrent storage operations
- Single `notify_put_batch` RPC instead of N individual `notify_put` calls
- Each operation gets its own transport buffer to avoid race conditions

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
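As background for reviewers, here is a minimal sketch of what this batched client path could look like, assembled from the diff snippets discussed below. `store_one` and `_write_to_volume` are hypothetical names used only for illustration; `strategy.select_storage_volume()`, `meta_only()`, and `notify_put_batch.call()` appear in the actual diff.

```python
import asyncio

async def _put_batch(self, items: dict) -> None:
    """Sketch: store all items concurrently, then notify the controller once."""
    if not items:
        return

    # All items in the batch are routed to one storage volume
    # (whether that always holds is questioned in the review below).
    storage_volume_ref = self.strategy.select_storage_volume()

    async def store_one(key, value):
        # _write_to_volume stands in for the per-key storage path; each call
        # gets its own transport buffer to avoid race conditions.
        request = await self._write_to_volume(storage_volume_ref, key, value)
        return key, request

    # Run the N storage operations concurrently instead of awaiting them one by one.
    put_results = await asyncio.gather(
        *(store_one(key, value) for key, value in items.items())
    )

    # A single RPC carries the metadata for every key in the batch.
    notifications = [
        (key, request.meta_only(), storage_volume_ref.volume_id)
        for key, request in put_results
    ]
    await self._controller.notify_put_batch.call(notifications)
```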
Implements a batched notification endpoint that processes multiple put operations in a single RPC call, reducing round-trip overhead when storing many tensors (e.g., model state_dicts). This complements the `_put_batch` method in `LocalClient` and enables efficient automatic batching in `put_state_dict` operations.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
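On the controller side, the batched endpoint can simply loop over the payload the client builds. In this sketch, `@endpoint` and `assert_initialized` come from the existing controller code shown in the diff, while `_record_put` is a hypothetical helper standing in for the per-key bookkeeping that the single-key `notify_put` already performs.

```python
@endpoint
async def notify_put_batch(self, notifications):
    """Sketch: register metadata for many keys in a single RPC.

    notifications: list of (key, request, storage_volume_id) tuples,
    matching the payload built by the client's _put_batch.
    """
    self.assert_initialized()
    for key, request, storage_volume_id in notifications:
        # _record_put is a placeholder for the per-key bookkeeping that
        # the existing notify_put endpoint performs.
        self._record_put(key, request, storage_volume_id)
```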
Changes `put_state_dict` to automatically batch all tensor storage operations internally using `_put_batch`, eliminating the need for users to manually call a separate batch API. Also enables parallel fetching in `get_state_dict`.

Benefits:
- `put_state_dict` now uses a single RPC for all parameters (was N RPCs)
- Parallel storage operations improve throughput
- `get_state_dict` fetches all tensors concurrently
- Batching is transparent: no API changes for users

This addresses reviewer feedback to make batching the default behavior rather than requiring explicit `put_batch` API calls.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
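At the state_dict level, the change then mostly amounts to building one mapping and delegating to the batched path. The sketch below assumes a `f"{key}/{name}"` key layout and a `store.get` call, which are illustrative rather than the exact torchstore API; the `put_state_dict(store, state_dict, key)` signature matches the diff shown further down.

```python
import asyncio

async def put_state_dict(store, state_dict, key):
    # Flatten the state dict into per-parameter keys and hand the whole
    # mapping to the client's internal batched put in one call.
    items = {f"{key}/{name}": tensor for name, tensor in state_dict.items()}
    await store._put_batch(items)

async def get_state_dict(store, state_dict, key):
    # Fetch every parameter concurrently instead of one await per tensor.
    names = list(state_dict.keys())
    tensors = await asyncio.gather(*(store.get(f"{key}/{n}") for n in names))
    return dict(zip(names, tensors))
```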
Codecov Report

❌ Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
## main #122 +/- ##
=======================================
Coverage ? 58.43%
=======================================
Files ? 32
Lines ? 3318
Branches ? 0
=======================================
Hits ? 1939
Misses ? 1379
Partials ? 0
Weight Sync Benchmark:
| Variable | Default | Description |
|---|---|---|
| BENCH_MODEL | synthetic-small | One of: synthetic-small, synthetic-large, qwen3-4b, qwen3-30b |
| BENCH_ITER | 5 | Number of timed iterations |
| BENCH_WARMUP | 1 | Number of warmup iterations |
| BENCH_USE_HF | 0 | Set to 1 to load a real HuggingFace model |
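For reference, the benchmark presumably consumes these variables along the following lines; the variable names come from the table above, but the default handling shown here is illustrative.

```python
import os

model = os.environ.get("BENCH_MODEL", "synthetic-small")
iterations = int(os.environ.get("BENCH_ITER", "5"))
warmup = int(os.environ.get("BENCH_WARMUP", "1"))
use_hf = os.environ.get("BENCH_USE_HF", "0") == "1"
```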
Results
Machine: gpu-dev-9ce1cf6f (1.9 TB RAM, 4x GPU)
Transport: MonarchRPC
State dict: 327 parameters, 7.82 GB (bfloat16, synthetic Qwen3-4B structure)
Iterations: 3 timed + 1 warmup
Feature Branch (batched)
| Metric | Mean | Median | Stdev | Min | Max |
|---|---|---|---|---|---|
| put_state_dict | 10.759s | 10.819s | 0.195s | 10.540s | 10.917s |
| get_state_dict | 11.862s | 11.963s | 0.448s | 11.372s | 12.252s |
| round_trip | 22.621s | 22.781s | 0.287s | 22.289s | 22.792s |

| Throughput | Mean | Peak |
|---|---|---|
| put (GB/s) | 0.73 | 0.74 |
| get (GB/s) | 0.66 | 0.69 |
| round_trip (GB/s) | 0.35 | 0.35 |
Main Branch (sequential)
| Metric | Mean | Median | Stdev | Min | Max |
|---|---|---|---|---|---|
| put_state_dict | 19.297s | 19.296s | 0.016s | 19.281s | 19.313s |
| get_state_dict | 22.526s | 22.302s | 0.485s | 22.194s | 23.083s |
| round_trip | 41.823s | 41.598s | 0.500s | 41.475s | 42.396s |

| Throughput | Mean | Peak |
|---|---|---|
| put (GB/s) | 0.41 | 0.41 |
| get (GB/s) | 0.35 | 0.35 |
| round_trip (GB/s) | 0.19 | 0.19 |
Comparison
| Operation | main | feature | Speedup |
|---|---|---|---|
| put_state_dict | 19.30s | 10.76s | 1.79x |
| get_state_dict | 22.53s | 11.86s | 1.90x |
| Round-trip | 41.82s | 22.62s | 1.85x |
The feature branch is ~1.85x faster overall for weight sync operations.
tests/bench_weight_sync.py
Outdated
| "qwen3-30b": (4096, 11008, 151936, 48), | ||
| } | ||
|
|
||
| HF_MODEL_MAP = { |
Let's just take the HF model names directly? Don't think mapping is necessary
tests/bench_weight_sync.py
Outdated
    }


    def create_synthetic_state_dict(model_type, dtype=torch.bfloat16):
Is this useful? I think in most cases you want to benchmark with a real model, maybe smaller
tests/bench_weight_sync.py
Outdated
    get_times = []

    for i in range(self.iterations):
        key = f"bench_{i}"
Store takes advantage of cached in-place gets / puts. We should keep the key prefix static (maybe just use v0)
tests/bench_weight_sync.py
Outdated
    # Warmup
    print(f"[Actor] Running {self.warmup} warmup iteration(s)...")
    for i in range(self.warmup):
        await ts.put_state_dict(state_dict, f"warmup_{i}")
Same comment as below, use the same key so that the allocation caches are warmed up
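A fragment sketching the suggested change to the benchmark loop, reusing one static key (the reviewer suggests `v0`) for both warmup and timed iterations so the store's allocation caches stay warm. This mirrors the surrounding benchmark code (`ts`, `self.warmup`, `self.iterations`) rather than standing alone.

```python
KEY = "v0"  # static key so cached in-place puts/gets are reused across iterations

# Warmup hits the same key as the timed loop, so allocation caches are warm.
for _ in range(self.warmup):
    await ts.put_state_dict(state_dict, KEY)

# Timed iterations reuse the key instead of f"bench_{i}".
for _ in range(self.iterations):
    await ts.put_state_dict(state_dict, KEY)
```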
tests/bench_weight_sync.py
Outdated
    assert model_type in MODEL_CONFIGS, f"Unknown model: {model_type}"

    # Disable transports that may not be available
Ideally this isn't done and we just pass a configurable transport type to the strategy
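Roughly what that suggestion could look like, assuming the benchmark can pass a transport type straight to the strategy. The `BENCH_TRANSPORT` variable and the import paths are assumptions for illustration; the log further down does show `TorchStoreStrategy` being initialized with `default_transport_type`.

```python
import os

# Import paths are assumed for illustration.
from torchstore import TorchStoreStrategy, TransportType

transport = TransportType[os.environ.get("BENCH_TRANSPORT", "SharedMemory")]
strategy = TorchStoreStrategy(default_transport_type=transport)
```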
torchstore/client.py
Outdated
    Args:
        items: Dictionary mapping keys to values to store.
    """
    if not items:
torchstore/client.py
Outdated
    latency_tracker = LatencyTracker(f"put_batch:{len(items)}_keys")

    # Select storage volume (all items go to same volume)
Is this actually always true? I'm not sure on the flow @LucasLLC
torchstore/client.py
Outdated
    # Select storage volume (all items go to same volume)
    storage_volume_ref = self.strategy.select_storage_volume()
    latency_tracker.track_step("select storage volume")
nit: "strategy.select_storage_volume"
        (key, request.meta_only(), storage_volume_ref.volume_id)
        for key, request in put_results
    ]
    await self._controller.notify_put_batch.call(notifications)
        notifications: List of (key, request, storage_volume_id) tuples.
    """
    self.assert_initialized()
    for key, request, storage_volume_id in notifications:
nit: I think this is just duplicate code from notify_put (single), we should move to helper
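One way to address this nit, assuming a hypothetical `_register_put` helper holding the shared per-key logic so the single and batched endpoints no longer duplicate it:

```python
def _register_put(self, key, request, storage_volume_id):
    # Shared bookkeeping used by both the single and batched endpoints.
    ...

@endpoint
async def notify_put(self, key, request, storage_volume_id):
    self.assert_initialized()
    self._register_put(key, request, storage_volume_id)

@endpoint
async def notify_put_batch(self, notifications):
    self.assert_initialized()
    for key, request, storage_volume_id in notifications:
        self._register_put(key, request, storage_volume_id)
```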
    logger = getLogger(__name__)


    async def put_state_dict(store, state_dict, key):
Let's expose these as new functions for safer rollout? put_state_dict_batch, get_state_dict_batch
DCPParityTest failed. I think that's worth investigating.
        self.warmup = warmup

    @endpoint
    async def run_benchmark(self):
This should also have a verification step asserting equality between the original model and the retrieved one. Could be a checksum comparison if it's too large to torch.equal... not sure
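A sketch of such a verification step; whether to use full `torch.equal` or a cheaper checksum is exactly the open question above, so both are shown behind a flag.

```python
import torch

def verify_state_dicts(original, retrieved, use_checksum=False):
    """Assert that the retrieved state dict matches the original one."""
    assert original.keys() == retrieved.keys(), "parameter names differ"
    for name, expected in original.items():
        actual = retrieved[name]
        assert expected.shape == actual.shape, f"shape mismatch for {name}"
        assert expected.dtype == actual.dtype, f"dtype mismatch for {name}"
        if use_checksum:
            # Cheaper alternative for very large tensors: compare float64 sums.
            assert expected.double().sum() == actual.double().sum(), name
        else:
            assert torch.equal(expected, actual), f"value mismatch for {name}"
```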
added, now I got this using shared memory:
```
warnings.warn('resource_tracker: There appear to be %d '
TORCHSTORE_RDMA_ENABLED=0 USE_TORCHCOMMS_RDMA=0 \
BENCH_ITER=2 BENCH_WARMUP=1 \
pytest tests/bench_weight_sync.py -vv -s
============================================== test session starts ==============================================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0 -- /home/dev/.conda/envs/vllm/bin/python3.12
cachedir: .pytest_cache
rootdir: /home/dev/kai/torchstore
configfile: pyproject.toml
plugins: anyio-4.12.1, asyncio-1.3.0, unordered-0.7.0, typeguard-4.4.4
asyncio: mode=Mode.STRICT, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collecting ... Warning: setting HYPERACTOR_CODEC_MAX_FRAME_LENGTH since this needs to be set to enable large RPC calls via Monarch
collected 1 item
tests/bench_weight_sync.py::test_benchmark_weight_sync
Model config: hidden=2560, intermediate=8960, vocab=151936, layers=36
======================================================================
WEIGHT SYNC BENCHMARK
======================================================================
Model: Qwen/Qwen3-4B
Transport: SharedMemory
Max concurrent: 8
Iterations: 2 (+ 1 warmup)
======================================================================
Initializing TorchStoreStrategy with default_transport_type=<TransportType.SharedMemory: 6>
Monarch internal logs are being written to /tmp/dev/monarch_log.log; execution id dev_Feb-10_22:46_43
[Actor] Loading HuggingFace model: Qwen/Qwen3-4B...
[actor=<root>.<tests.bench_weight_sync.BenchmarkActor bench_actor_f0a82f01{'gpus': 0/1}>] `torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|███████████████████████████████████████████████████| 3/3 [00:00<00:00, 78.90it/s]
[Actor] State dict: 399 parameters, 8.22 GB
[Actor] Running 1 warmup iteration(s)...
[Actor] Warmup complete.
[Actor] Running 2 timed iteration(s)...
Iter 1: put=0.938s (8.76 GB/s), get=1.033s (7.95 GB/s), total=1.971s
Iter 2: put=0.920s (8.93 GB/s), get=1.032s (7.96 GB/s), total=1.953s
[Actor] Verifying correctness of last get...
[Actor] Verification passed.
======================================================================
RESULTS SUMMARY
======================================================================
Parameters: 399
State dict size: 8414.1 MB (8.22 GB)
======================================================================
Metric Mean Median Stdev Min Max
----------------------------------------------------------------------
put_state_dict_batch 0.929s 0.929s 0.012s 0.920s 0.938s
get_state_dict_batch 1.033s 1.033s 0.001s 1.032s 1.033s
round_trip 1.962s 1.962s 0.013s 1.953s 1.971s
Throughput Mean Peak
----------------------------------------
put (GB/s) 8.85 8.93
get (GB/s) 7.96 7.96
round_trip (GB/s) 4.19 4.21
======================================================================
PASSED
============================================== 1 passed in 43.69s ===============================================
```
LucasLLC left a comment:
@wukaixingxp do you see any value in maintaining the old way of doing put/get state dict? shouldn't we just always do this?
TYSM! Looking forward to seeing these perf gains!
Make batching automatic in put_state_dict operations
Summary
Makes batching completely automatic inside put_state_dict() operations - no API changes required; batching happens transparently.
Key Changes
- _put_batch() - parallelizes storage and batches controller notifications
- notify_put_batch() - single RPC instead of N individual calls
- put_state_dict() - uses _put_batch() internally
- get_state_dict() - fetches tensors concurrently

Performance Impact

- put_state_dict: 1 RPC instead of N RPCs for notifications

Dependencies

Independent of PRs #120 and #121; can merge in any order.