
Conversation

Collaborator

@daihaowz commented Dec 7, 2025

Description

In single-controller mode, all data passes through the central controller, which creates a single-point bottleneck. This PR implements DistributedBatchMemory: the control plane passes only metadata, and the actual data is fetched on demand via HTTP from the distributed nodes, eliminating the bottleneck on the central side.
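For illustration, here is a minimal sketch of the idea; the class, field, and method names below (ShardRef, BatchHandle, fetch) are hypothetical and not the exact API added in this PR. The controller only hands around a small handle describing where each shard lives, and the payload is pulled over HTTP from the owning node only when it is actually consumed.

```python
# Minimal sketch only: hypothetical names, simplified synchronous client.
from dataclasses import dataclass, field

import requests  # the PR uses an async HTTP client; plain requests keeps the sketch short


@dataclass
class ShardRef:
    node_addr: str  # e.g. "10.0.0.5:8080", the worker node holding the shard
    shard_id: str   # key under which that node serves the shard


@dataclass
class BatchHandle:
    batch_id: str
    shards: list[ShardRef] = field(default_factory=list)

    def fetch(self) -> list[bytes]:
        # Only the worker that consumes the batch pulls the payloads;
        # the central controller never touches this path.
        return [
            requests.get(f"http://{s.node_addr}/data/{s.shard_id}", timeout=60).content
            for s in self.shards
        ]
```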

Related Issue

Fixes #(issue)

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not
    work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with
    jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

Additional Context


Need help? Check the Contributing Guide or ask in
GitHub Discussions!

@gemini-code-assist
Contributor

Summary of Changes

Hello @daihaowz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a fundamental architectural shift in how batch data is handled, moving from a centralized memory-based approach to a distributed, metadata-driven system. The primary goal is to alleviate performance bottlenecks by allowing data to reside on worker nodes and be fetched on-demand, rather than passing all data through a central controller. This change enables more scalable and efficient data processing in distributed environments by separating data management from control flow.

Highlights

  • Distributed Batch Memory Implementation: Introduced a new DistributedBatchMemory class that separates batch metadata from the actual data. This allows the control plane to pass only metadata, while data is fetched on-demand from distributed nodes via HTTP, effectively eliminating the single-point bottleneck of central data processing.
  • Metadata-Driven Data Access: Implemented BatchMetadata, ShardMetadata, and TensorMetadata dataclasses to describe the structure, location, and properties of distributed data shards. The DistributedBatchMemory class now operates primarily on this metadata, fetching data lazily when required.
  • Asynchronous Data Client: Added a BatchDataClient for asynchronous HTTP communication with worker nodes to fetch and store data shards. This client handles the network requests for distributed data retrieval and storage.
  • RPC Server Enhancements for Distributed Data: The RPC server (rpc_server.py) has been significantly updated to host and serve distributed batch data. New endpoints (/data/<shard_id>, /data/clear, /data/stats) are introduced for storing, retrieving, and managing data shards. It also includes a dedicated NCCL worker thread to ensure compatibility with NCCL operations while allowing concurrent data requests.
  • Integration with Controllers and Serialization: Updated RolloutController and TrainController to utilize the new DistributedBatchMemory for handling batches, including methods for chunking, unioning, and clearing distributed data. The RPC serialization logic has also been extended to correctly serialize and deserialize DistributedBatchMemory objects.

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and well-designed feature for distributed batch handling. By passing metadata instead of full data payloads, it effectively addresses the single-point bottleneck at the central controller. The implementation includes lazy data fetching via an HTTP client, updates to the RPC server to handle metadata and serve data shards, and comprehensive tests. The use of a dedicated NCCL worker thread in the RPC server to avoid concurrency issues is a particularly strong design choice. I've included a couple of suggestions for minor improvements related to efficiency and code clarity, but overall, this is an excellent contribution.

"""Merge two batches in metadata mode by modifying self in-place."""
# Combine shards from both batches
all_shards = self.metadata.shards + other.metadata.shards
max_global_step = max(self.metadata.global_step, other.metadata.global_step)
Collaborator

Why take the maximum?

Collaborator Author

Same as #687 (comment).

We use the maximum to prevent GC of batches that are still in use.

"""Merge two batches in metadata mode by modifying self in-place."""
# Combine shards from both batches
all_shards = self.metadata.shards + other.metadata.shards
max_global_step = max(self.metadata.global_step, other.metadata.global_step)
Collaborator

Could you clarify under what circumstances we expect to merge batches with different steps?
The choice of max() implies a specific semantic: "the merged batch represents the latest step." But is this the intended behavior?

Collaborator Author

When rollout is called (train controller -> rollout controller -> inference engine), the returned global_step is the version of the inference engine.
When compute_logp is called (train controller -> train controller -> train engine), the returned global_step is the version of the train engine.

I understand there is a natural difference between the two in async mode.

Collaborator Author

I added some debug logs in the trainer like this:
[screenshot of the added debug logging]

and the log is as below when max_head_offpolicyness=2:

# grep "after " grpo-new-1.log  | grep global_step
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=fbf47d65-ccc0-4d27-802f-3451c845dd00, num_shards=16, global_step=0, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=d632a5ce-470d-4743-8462-a1d897fd75ab, num_shards=20, global_step=0, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=754e8a47-5bc4-4a76-9f80-4d776aa2ae43, num_shards=4, global_step=0, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=0fc36e74-26b5-4ea6-90a2-e50845655aa4, num_shards=16, global_step=0, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=64cd5d46-d3cc-47ae-81b9-2c4d8e48bd9e, num_shards=20, global_step=1, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=0d528e4c-0b5c-4258-bd21-029b20b84015, num_shards=4, global_step=1, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=036fdfcb-7cb6-4437-9895-7db7b1742ebb, num_shards=16, global_step=0, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=6e254a5f-2578-47f8-bb2e-6eb75a2b5af5, num_shards=20, global_step=2, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=fbfadff4-16d5-4ea4-a332-0636093a7029, num_shards=4, global_step=2, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=b4742c7d-3363-485a-950f-ab523f1f36d8, num_shards=16, global_step=2, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=c0c31b6c-e086-4041-98c3-c6743db8e979, num_shards=20, global_step=3, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=775f0c26-406f-4d80-9095-29b48f28a757, num_shards=4, global_step=3, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=5073d258-c084-4f57-9503-128b386ab593, num_shards=16, global_step=3, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=da2f4daf-bb25-4fe5-a0f0-255f20c0855f, num_shards=20, global_step=4, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=867be65f-e27d-4a20-94d5-138ba6b1f043, num_shards=4, global_step=4, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=6374ba00-adc3-4afa-b815-ee0bcd9cea70, num_shards=16, global_step=4, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=c51e1dd7-824a-40cf-9172-1fe90ab72a67, num_shards=20, global_step=5, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=b90398e3-8f8d-4a62-b4f3-622ed2db34ae, num_shards=4, global_step=5, data_loaded=False>

Introduces an explicit `BatchMode` enum to `DistributedBatchMemory` to strictly distinguish between local data and remote metadata states, ensuring consistency during operations like chunking and concatenation. Renames RPC keyword arguments to remove leading underscores for better API clarity (e.g., `should_broadcast`, `result_key`).
rchardx and others added 5 commits December 8, 2025 16:15
Updates the string representation of `DistributedBatchMemory` to display "size" instead of "total_size" and adjusts unit tests to match.
Simplifies docstrings for several internal methods by removing verbose parameter and return value descriptions.
- Add guards to `DistributedBatchMemory` to prevent item access or deletion while in remote mode or when the dataset is empty, ensuring data is fetched first.
- Update `__setitem__` to remove the implicit union merging behavior and strictly enforce string keys.
- Remove a redundant null check when filtering trajectories in the rollout controller.
…nnections

- Support direct merging of DistributedBatchMemory instances via assignment by invoking the union method.
- Introduce a connection limit to the BatchDataClient to manage concurrent HTTP connections more effectively.
- Refactor session management in BatchDataClient to use context managers and share TCP connectors, preventing resource exhaustion and improving cleanup.
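For reference, a minimal sketch of the shared-connector pattern described in the commit above, assuming aiohttp; the class and parameter names are illustrative rather than the actual BatchDataClient implementation.

```python
# Illustrative sketch of a shared TCPConnector with a connection limit
# (not the actual BatchDataClient code).
import aiohttp


class ShardFetcher:
    def __init__(self, max_connections: int = 64):
        # A single shared connector caps concurrent TCP connections across
        # all fetches, preventing resource exhaustion.
        self._connector = aiohttp.TCPConnector(limit=max_connections)

    async def fetch(self, node_addr: str, shard_id: str) -> bytes:
        # connector_owner=False keeps the shared connector alive when the
        # per-request session context manager exits.
        async with aiohttp.ClientSession(
            connector=self._connector, connector_owner=False
        ) as session:
            async with session.get(f"http://{node_addr}/data/{shard_id}") as resp:
                resp.raise_for_status()
                return await resp.read()
```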
Comment on lines +31 to +46
class BatchMode(Enum):
"""Explicit mode enum for DistributedBatchMemory.
Attributes
----------
LOCAL : auto
Data stored locally in memory
REMOTE : auto
Only metadata; data fetched on-demand via HTTP
EMPTY : auto
Neither data nor metadata present (invalid/empty state)
"""

LOCAL = auto() # Data stored locally in memory
REMOTE = auto() # Only metadata; data fetched on-demand
EMPTY = auto() # Neither present (invalid state)
Collaborator

Better to rename this to BatchStatus.

def union(self, other: "DistributedBatchMemory") -> "DistributedBatchMemory":
"""Merge another batch with this one"""
merged_data = {k: v for k, v in self.dataset.items()}
def union(self, other: DistributedBatchMemory) -> DistributedBatchMemory:
Collaborator

I recommend implementing union as an out-of-place operation, i.e., it should return a new DistributedBatch; otherwise it should be named union_, following PyTorch's naming convention.
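For concreteness, a sketch of the suggested split, following the PyTorch trailing-underscore convention; only the dataset dict from the quoted snippet is assumed here, and the class is a simplified stand-in rather than the real DistributedBatchMemory.

```python
# Sketch of the suggested naming split (simplified stand-in class).
class DistributedBatchMemorySketch:
    def __init__(self, dataset: dict | None = None):
        self.dataset = dict(dataset or {})

    def union(self, other: "DistributedBatchMemorySketch") -> "DistributedBatchMemorySketch":
        """Out-of-place: return a new batch; neither input is mutated."""
        return DistributedBatchMemorySketch({**self.dataset, **other.dataset})

    def union_(self, other: "DistributedBatchMemorySketch") -> "DistributedBatchMemorySketch":
        """In-place variant: mutates self; trailing underscore as in PyTorch."""
        self.dataset.update(other.dataset)
        return self
```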

return self.dispatcher.runner

# ==================== DISTRIBUTED BATCH RPC WRAPPERS ====================
def clear_batches(self, global_step: int):
Collaborator

Clearing batches by global step is not a good idea. Due to asynchronous training, the definition of a training step is somewhat vague. I recommend clearing by task IDs.
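A possible shape for worker-side clearing by task ID; the storage layout and function name here are assumptions, not code from this PR.

```python
# Sketch of clearing by task ID instead of by global step
# (hypothetical worker-side storage layout).
_batch_storage: dict[str, bytes] = {}  # task_id -> serialized shard payload


def clear_tasks(task_ids: list[str]) -> int:
    """Drop data for finished tasks; return the number of bytes freed."""
    freed = 0
    for task_id in task_ids:
        payload = _batch_storage.pop(task_id, None)
        if payload is not None:
            freed += len(payload)
    return freed
```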

Comment on lines +29 to +30
batch_size: int
offset: int = 0
Collaborator

I don't think we need batch size or offset. We can just store data individually according to prompt/input/task ID, where each piece of data corresponds to an entry in the training dataset.

"""Metadata for a distributed batch sharded across multiple nodes."""

batch_id: str
global_step: int
Collaborator

We only require shard IDs for clearing data. What else is this entry used for?

app = Flask(__name__)


def _init_nccl_worker():
Collaborator

Nice workaround!
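For readers unfamiliar with the pattern: NCCL calls generally need to be issued from one consistent thread, while Flask handles each HTTP request on its own thread. Below is a minimal sketch of the queue-based worker, with illustrative names rather than the PR's exact implementation.

```python
# Sketch of the single-thread NCCL worker pattern (illustrative names).
import queue
import threading

_nccl_queue: "queue.Queue[tuple]" = queue.Queue()


def _nccl_worker_loop():
    # All NCCL-touching work funnels through this one thread, so Flask can
    # keep serving /data/* requests concurrently on its own request threads.
    while True:
        fn, args, done = _nccl_queue.get()
        try:
            done["result"] = fn(*args)
        except Exception as exc:  # surface errors to the caller
            done["error"] = exc
        finally:
            done["event"].set()


def submit_nccl(fn, *args):
    """Run fn(*args) on the dedicated NCCL thread and block for the result."""
    done = {"event": threading.Event()}
    _nccl_queue.put((fn, args, done))
    done["event"].wait()
    if "error" in done:
        raise done["error"]
    return done.get("result")


threading.Thread(target=_nccl_worker_loop, daemon=True).start()
```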


data_bytes = request.get_data()
buffer = io.BytesIO(data_bytes)
data = pickle.load(buffer)
Collaborator

Again, can we avoid pickling and instead implement a customized serialization protocol in scheduler/rpc/serialization.py?
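A rough sketch of what the suggestion could look like; `deserialize` is only a placeholder for whatever scheduler/rpc/serialization.py actually exposes, and the route's HTTP method and storage dict are assumptions.

```python
# Sketch of swapping pickle for a project-level serialization layer
# (placeholder protocol; not the actual rpc_server.py code).
import json

from flask import Flask, request

app = Flask(__name__)
_batch_storage: dict[str, object] = {}


def deserialize(raw: bytes):
    # Stand-in for the custom protocol; real tensors would need the
    # scheduler/rpc/serialization.py codec rather than JSON.
    return json.loads(raw)


@app.route("/data/<shard_id>", methods=["POST"])
def store_shard(shard_id: str):
    # No pickle.load on raw request bytes; go through the shared protocol instead.
    _batch_storage[shard_id] = deserialize(request.get_data())
    return {"ok": True}
```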

Comment on lines +638 to +639
_batch_storage[shard_id] = (global_step, data)
_batch_storage_stats[shard_id] = len(data_bytes)
Collaborator

Here shard_id should be task_id + key, e.g., as a dataclass, which uniquely identifies a tensor across the training procedure.
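A minimal sketch of such a structured key (names are assumptions, not code from this PR):

```python
# Sketch of a structured storage key, as suggested above.
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen -> hashable, so it can be used as a dict key
class ShardKey:
    task_id: str  # identifies the dataset entry / trajectory
    key: str      # tensor name, e.g. "input_ids" or "logprobs"


_batch_storage: dict[ShardKey, bytes] = {}

# e.g. _batch_storage[ShardKey("task-0042", "logprobs")] = payload
```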

# Import here to avoid circular dependency
if isinstance(value, DistributedBatchMemory):
# Use __getstate__ to get serializable state
state = value.__getstate__()
Collaborator

We can just copy-paste the `__getstate__` and `__setstate__` functions here.
