
Feature: Make batching automatic in put_state_dict operations #122

Open

wukaixingxp wants to merge 10 commits into meta-pytorch:main from wukaixingxp:feature/auto-batch-state-dict

Conversation

@wukaixingxp

Make batching automatic in put_state_dict operations

Summary

Makes batching completely automatic inside put_state_dict()
operations - no API changes required, batching happens transparently.

Key Changes

  • Internal _put_batch() - Parallelizes storage + batches
    controller notifications
  • Controller notify_put_batch() - Single RPC instead of N
    individual calls
  • Auto-batching in put_state_dict() - Uses _put_batch()
    internally
  • Parallel get_state_dict() - Fetches tensors concurrently
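
Pieced together from the code fragments quoted in the review below, the internal batched path plausibly looks like the following sketch. LocalClient, _put_batch, notify_put_batch, select_storage_volume, and meta_only all appear in this PR; the _put_one helper and the exact signatures are assumptions for illustration.

import asyncio

class LocalClient:
    async def _put_batch(self, items):
        """Sketch: store all items in parallel, then send one batched
        controller notification instead of N individual notify_put RPCs."""
        if not items:
            return

        # All items are routed to a single storage volume.
        storage_volume_ref = self.strategy.select_storage_volume()

        # Parallel storage operations; each put gets its own transport
        # buffer to avoid race conditions (_put_one is hypothetical).
        put_results = await asyncio.gather(
            *(self._put_one(storage_volume_ref, key, value)
              for key, value in items.items())
        )

        # Single RPC carrying all notifications.
        notifications = [
            (key, request.meta_only(), storage_volume_ref.volume_id)
            for key, request in put_results
        ]
        await self._controller.notify_put_batch.call(notifications)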

Performance Impact

  • put_state_dict: 1 RPC instead of N RPCs for notifications
  • Parallel storage operations improve throughput
  • Zero code changes needed for users

Dependencies

Independent of PRs #120 and #121 - can merge in any order

Developer and others added 3 commits February 7, 2026 00:13
Implements _put_batch() as an internal method that parallelizes storage
operations and batches controller notifications into a single RPC call.
This provides the foundation for automatic batching in put_state_dict.

Key optimizations:
- Parallel asyncio.gather for concurrent storage operations
- Single notify_put_batch RPC instead of N individual notify_put calls
- Each operation gets its own transport buffer to avoid race conditions

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

Implements batched notification endpoint that processes multiple put
operations in a single RPC call, reducing round-trip overhead when
storing many tensors (e.g., model state_dicts).

This complements the _put_batch method in LocalClient and enables
efficient automatic batching in put_state_dict operations.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

Changes put_state_dict to automatically batch all tensor storage operations
internally using _put_batch, eliminating the need for users to manually call
a separate batch API. Also enables parallel fetching in get_state_dict.

Benefits:
- put_state_dict now uses single RPC for all parameters (was N RPCs)
- Parallel storage operations improve throughput
- get_state_dict fetches all tensors concurrently
- Batching is transparent - no API changes for users

This addresses reviewer feedback to make batching the default behavior
rather than requiring explicit put_batch API calls.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
@codecov-commenter commented Feb 7, 2026

Codecov Report

❌ Patch coverage is 30.97643% with 205 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@395aa69). Learn more about missing BASE report.

Files with missing lines         Patch %   Lines
tests/bench_weight_sync.py         0.00%   108 Missing
tests/test_put_batch.py           62.83%    42 Missing
torchstore/client.py              10.34%    26 Missing
torchstore/state_dict_utils.py    14.28%    24 Missing
torchstore/api.py                 42.85%     4 Missing
torchstore/controller.py          91.66%     1 Missing
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #122   +/-   ##
=======================================
  Coverage        ?   58.43%           
=======================================
  Files           ?       32           
  Lines           ?     3318           
  Branches        ?        0           
=======================================
  Hits            ?     1939           
  Misses          ?     1379           
  Partials        ?        0           

@wukaixingxp (Author)

Weight Sync Benchmark: feature/auto-batch-state-dict vs main

What This Tests

This benchmark isolates the performance of put_state_dict and get_state_dict in torchstore -- the core operations used for weight synchronization between trainer and generator in forge.

Branch Changes

The feature/auto-batch-state-dict branch modifies two functions in torchstore/state_dict_utils.py:

put_state_dict (writing weights to the store):

  • main: Sequentially calls await store.put() for each parameter, one at a time. Each call makes its own RPC to notify the controller. For a model with N parameters, this is N sequential puts + N sequential RPCs.
  • feature: Batches all N puts in parallel using asyncio.gather, then notifies the controller with a single notify_put_batch RPC call. This is N parallel puts + 1 RPC.
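
From the call site, the difference is roughly the following sketch (the f"{key}/{name}" key scheme and the exact signatures are assumptions; store.put, _put_batch, and asyncio.gather come from the PR description):

# main: N sequential puts, each triggering its own controller notification RPC
async def put_state_dict(store, state_dict, key):
    for name, tensor in state_dict.items():
        await store.put(f"{key}/{name}", tensor)

# feature: one internal batch call -> N parallel puts + 1 notify_put_batch RPC
async def put_state_dict(store, state_dict, key):
    await store._put_batch({f"{key}/{name}": t for name, t in state_dict.items()})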

get_state_dict (reading weights from the store):

  • main: Sequentially calls await store.get() for each parameter, one at a time. The parallel version was commented out.
  • feature: Runs all N gets concurrently using asyncio.gather.
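
In sketch form (signatures and the key scheme are assumptions; only store.get and asyncio.gather come from the description above):

import asyncio

# main: strictly sequential, one awaited get per parameter
async def get_all_sequential(store, key, names):
    return {name: await store.get(f"{key}/{name}") for name in names}

# feature: issue all gets at once and gather the results concurrently
async def get_all_concurrent(store, key, names):
    tensors = await asyncio.gather(*(store.get(f"{key}/{name}") for name in names))
    return dict(zip(names, tensors))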

How The Benchmark Works

The benchmark (tests/bench_weight_sync.py) does the following:

  1. Creates a synthetic state dict matching Qwen3-4B model structure:

    • 327 parameters (36 transformer layers)
    • hidden_size=2560, intermediate_size=8960, vocab_size=151936
    • Total size: 7.82 GB in bfloat16
    • Tensor shapes match real model: attention projections (2560x2560), MLP layers (8960x2560), embeddings (151936x2560), layer norms (2560,)
  2. Runs inside a Monarch Actor (required by torchstore architecture)

  3. Executes warmup iterations, then timed iterations of:

    • ts.put_state_dict(state_dict, key) -- store all weights
    • ts.get_state_dict(key, state_dict) -- retrieve all weights
  4. Reports per-iteration timing and aggregate statistics (mean, median, stdev, min, max, throughput in GB/s)
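
Condensed from the fragments of tests/bench_weight_sync.py quoted in the review comments below, the timed loop plausibly looks like this sketch (the ts alias, total_gb argument, and print format are assumptions):

import statistics
import time

import torchstore as ts  # assumed alias; the benchmark calls ts.put_state_dict(...)

async def timed_loop(state_dict, iterations, warmup, total_gb):
    # Warmup primes allocation caches and transport buffers.
    for i in range(warmup):
        await ts.put_state_dict(state_dict, f"warmup_{i}")

    put_times, get_times = [], []
    for i in range(iterations):
        key = f"bench_{i}"
        t0 = time.monotonic()
        await ts.put_state_dict(state_dict, key)   # store all weights
        t1 = time.monotonic()
        await ts.get_state_dict(key, state_dict)   # retrieve all weights
        t2 = time.monotonic()
        put_times.append(t1 - t0)
        get_times.append(t2 - t1)

    # Aggregate statistics, e.g. mean latency and throughput in GB/s.
    mean_put = statistics.mean(put_times)
    print(f"put mean: {mean_put:.3f}s ({total_gb / mean_put:.2f} GB/s)")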

Running The Benchmark

# Default: Qwen3-4B dimensions, 5 iterations, 1 warmup
pytest tests/bench_weight_sync.py -v -s

# Custom iterations
BENCH_ITER=3 BENCH_WARMUP=1 pytest tests/bench_weight_sync.py -v -s

# Qwen3-30B dimensions (larger model)
BENCH_MODEL=synthetic-large BENCH_ITER=3 pytest tests/bench_weight_sync.py -v -s

# Real HuggingFace model (requires HF_TOKEN)
BENCH_MODEL=qwen3-4b BENCH_USE_HF=1 pytest tests/bench_weight_sync.py -v -s

Environment variables:

Variable       Default          Description
BENCH_MODEL    synthetic-small  synthetic-small, synthetic-large, qwen3-4b, qwen3-30b
BENCH_ITER     5                Number of timed iterations
BENCH_WARMUP   1                Number of warmup iterations
BENCH_USE_HF   0                Set to 1 to load a real HuggingFace model
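
A sketch of how the benchmark presumably reads these (plain os.environ lookups; the names and defaults come from the table above):

import os

BENCH_MODEL = os.environ.get("BENCH_MODEL", "synthetic-small")
BENCH_ITER = int(os.environ.get("BENCH_ITER", "5"))
BENCH_WARMUP = int(os.environ.get("BENCH_WARMUP", "1"))
BENCH_USE_HF = os.environ.get("BENCH_USE_HF", "0") == "1"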

Results

Machine: gpu-dev-9ce1cf6f (1.9 TB RAM, 4x GPU)
Transport: MonarchRPC
State dict: 327 parameters, 7.82 GB (bfloat16, synthetic Qwen3-4B structure)
Iterations: 3 timed + 1 warmup

Feature Branch (batched)

Metric                     Mean     Median      Stdev        Min        Max
----------------------------------------------------------------------
put_state_dict          10.759s    10.819s     0.195s    10.540s    10.917s
get_state_dict          11.862s    11.963s     0.448s    11.372s    12.252s
round_trip              22.621s    22.781s     0.287s    22.289s    22.792s

Throughput                 Mean       Peak
----------------------------------------
put (GB/s)                0.73       0.74
get (GB/s)                0.66       0.69
round_trip (GB/s)         0.35       0.35

Main Branch (sequential)

Metric                     Mean     Median      Stdev        Min        Max
----------------------------------------------------------------------
put_state_dict          19.297s    19.296s     0.016s    19.281s    19.313s
get_state_dict          22.526s    22.302s     0.485s    22.194s    23.083s
round_trip              41.823s    41.598s     0.500s    41.475s    42.396s

Throughput                 Mean       Peak
----------------------------------------
put (GB/s)                0.41       0.41
get (GB/s)                0.35       0.35
round_trip (GB/s)         0.19       0.19

Comparison

Operation        main      feature   Speedup
put_state_dict   19.30s    10.76s    1.79x
get_state_dict   22.53s    11.86s    1.90x
Round-trip       41.82s    22.62s    1.85x

The feature branch is ~1.85x faster overall for weight sync operations.

@wukaixingxp wukaixingxp marked this pull request as ready for review February 10, 2026 19:19
"qwen3-30b": (4096, 11008, 151936, 48),
}

HF_MODEL_MAP = {

Member: Let's just take the HF model names directly? Don't think mapping is necessary

}


def create_synthetic_state_dict(model_type, dtype=torch.bfloat16):

Member: Is this useful? I think in most cases you want to benchmark with a real model, maybe smaller

get_times = []

for i in range(self.iterations):
    key = f"bench_{i}"

Member: Store takes advantage of cached in-place gets / puts. We should keep the key prefix static (maybe just use v0)
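
A minimal sketch of the suggested fix (the v0 key name is the reviewer's; the comments are illustrative):

# Before: a fresh key per iteration defeats the store's cached in-place puts/gets.
key = f"bench_{i}"

# After: a static key lets every iteration reuse warmed allocations.
key = "bench_v0"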

# Warmup
print(f"[Actor] Running {self.warmup} warmup iteration(s)...")
for i in range(self.warmup):
    await ts.put_state_dict(state_dict, f"warmup_{i}")

Member: Same comment as below, use the same key so that the allocation caches are warmed up


assert model_type in MODEL_CONFIGS, f"Unknown model: {model_type}"

# Disable transports that may not be available

Member: Ideally this isn't done and we just pass a configurable transport type to the strategy

Args:
    items: Dictionary mapping keys to values to store.
"""
if not items:

Member: assert is not None?


latency_tracker = LatencyTracker(f"put_batch:{len(items)}_keys")

# Select storage volume (all items go to same volume)

Member: Is this actually always true? I'm not sure on the flow @LucasLLC


# Select storage volume (all items go to same volume)
storage_volume_ref = self.strategy.select_storage_volume()
latency_tracker.track_step("select storage volume")

Member: nit: "strategy.select_storage_volume"

(key, request.meta_only(), storage_volume_ref.volume_id)
for key, request in put_results
]
await self._controller.notify_put_batch.call(notifications)

Member: cool!

notifications: List of (key, request, storage_volume_id) tuples.
"""
self.assert_initialized()
for key, request, storage_volume_id in notifications:

Member: nit: I think this is just duplicate code from notify_put (single), we should move it to a helper
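
A sketch of that refactor (the _record_put helper name is hypothetical; assert_initialized and the tuple layout come from the quoted code):

def _record_put(self, key, request, storage_volume_id):
    """Hypothetical shared helper: the bookkeeping currently duplicated
    between notify_put and notify_put_batch."""
    ...  # record key -> (request metadata, storage volume) in controller state

async def notify_put(self, key, request, storage_volume_id):
    self.assert_initialized()
    self._record_put(key, request, storage_volume_id)

async def notify_put_batch(self, notifications):
    """notifications: List of (key, request, storage_volume_id) tuples."""
    self.assert_initialized()
    for key, request, storage_volume_id in notifications:
        self._record_put(key, request, storage_volume_id)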

logger = getLogger(__name__)


async def put_state_dict(store, state_dict, key):

Member: Let's expose these as new functions for safer rollout? put_state_dict_batch, get_state_dict_batch
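
The rollout shape being suggested, as a sketch (signatures mirror the put_state_dict(store, state_dict, key) definition quoted above; bodies elided). The put_state_dict_batch / get_state_dict_batch rows in the benchmark output further down suggest this naming was adopted.

# Existing entry points keep their current behavior during rollout...
async def put_state_dict(store, state_dict, key):
    ...

# ...while the batched implementations ship under new, opt-in names.
async def put_state_dict_batch(store, state_dict, key):
    """Batched variant: parallel puts plus a single notify_put_batch RPC."""
    ...

async def get_state_dict_batch(store, key, state_dict):
    """Batched variant: concurrent gets via asyncio.gather."""
    ...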

@amirafzali (Member)

DCPParityTest failed. I think that's worth investigating

self.warmup = warmup

@endpoint
async def run_benchmark(self):

Member: This should also have a verification step asserting equality between the original model and the retrieved one. Could be a checksum comparison if it's too large to torch.equal... not sure
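
One way that verification might look (a sketch: torch.equal per tensor, with a SHA-256 checksum as the fallback the reviewer floats for very large dicts; all names here are illustrative):

import hashlib

import torch

def verify_state_dicts(original, retrieved):
    """Assert the retrieved state dict is exactly equal to the original."""
    assert original.keys() == retrieved.keys(), "parameter sets differ"
    for name in original:
        assert torch.equal(original[name].cpu(), retrieved[name].cpu()), \
            f"tensor mismatch: {name}"

def state_dict_checksum(state_dict):
    """Alternative for huge dicts: SHA-256 over each tensor's raw bytes."""
    h = hashlib.sha256()
    for name in sorted(state_dict):
        t = state_dict[name].detach().cpu().flatten().contiguous()
        h.update(name.encode())
        h.update(t.view(torch.uint8).numpy().tobytes())
    return h.hexdigest()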

Author: Added. Now I get this using shared memory:

TORCHSTORE_RDMA_ENABLED=0 USE_TORCHCOMMS_RDMA=0 \
  BENCH_ITER=2 BENCH_WARMUP=1 \
  pytest tests/bench_weight_sync.py -vv -s
============================================== test session starts ==============================================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0 -- /home/dev/.conda/envs/vllm/bin/python3.12
cachedir: .pytest_cache
rootdir: /home/dev/kai/torchstore
configfile: pyproject.toml
plugins: anyio-4.12.1, asyncio-1.3.0, unordered-0.7.0, typeguard-4.4.4
asyncio: mode=Mode.STRICT, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collecting ... Warning: setting HYPERACTOR_CODEC_MAX_FRAME_LENGTH since this needs to be set to enable large RPC calls via Monarch
collected 1 item

tests/bench_weight_sync.py::test_benchmark_weight_sync
Model config: hidden=2560, intermediate=8960, vocab=151936, layers=36

======================================================================
WEIGHT SYNC BENCHMARK
======================================================================
Model:            Qwen/Qwen3-4B
Transport:        SharedMemory
Max concurrent:   8
Iterations:       2 (+ 1 warmup)
======================================================================

Initializing TorchStoreStrategy with default_transport_type=<TransportType.SharedMemory: 6>
Monarch internal logs are being written to /tmp/dev/monarch_log.log; execution id dev_Feb-10_22:46_43
[Actor] Loading HuggingFace model: Qwen/Qwen3-4B...
[actor=<root>.<tests.bench_weight_sync.BenchmarkActor bench_actor_f0a82f01{'gpus': 0/1}>] `torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|███████████████████████████████████████████████████| 3/3 [00:00<00:00, 78.90it/s]
[Actor] State dict: 399 parameters, 8.22 GB
[Actor] Running 1 warmup iteration(s)...
[Actor] Warmup complete.
[Actor] Running 2 timed iteration(s)...
  Iter 1: put=0.938s (8.76 GB/s), get=1.033s (7.95 GB/s), total=1.971s
  Iter 2: put=0.920s (8.93 GB/s), get=1.032s (7.96 GB/s), total=1.953s
[Actor] Verifying correctness of last get...
[Actor] Verification passed.

======================================================================
RESULTS SUMMARY
======================================================================
Parameters:       399
State dict size:  8414.1 MB (8.22 GB)
======================================================================
Metric                     Mean     Median      Stdev        Min        Max
----------------------------------------------------------------------
put_state_dict_batch     0.929s     0.929s     0.012s     0.920s     0.938s
get_state_dict_batch     1.033s     1.033s     0.001s     1.032s     1.033s
round_trip               1.962s     1.962s     0.013s     1.953s     1.971s

Throughput                 Mean       Peak
----------------------------------------
put (GB/s)                8.85       8.93
get (GB/s)                7.96       7.96
round_trip (GB/s)         4.19       4.21
======================================================================
PASSED

============================================== 1 passed in 43.69s ===============================================

@LucasLLC (Contributor) left a comment

@wukaixingxp do you see any value in maintaining the old way of doing put/get state dict? shouldn't we just always do this?

TYSM! Looking forward to seeing these perf gains!
