feat: implement distributed batch #687
base: main
Conversation
Summary of Changes
Hello @daihaowz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a fundamental architectural shift in how batch data is handled, moving from a centralized memory-based approach to a distributed, metadata-driven system. The primary goal is to alleviate performance bottlenecks by allowing data to reside on worker nodes and be fetched on-demand, rather than passing all data through a central controller. This change enables more scalable and efficient data processing in distributed environments by separating data management from control flow.
Code Review
This pull request introduces a significant and well-designed feature for distributed batch handling. By passing metadata instead of full data payloads, it effectively addresses the single-point bottleneck at the central controller. The implementation includes lazy data fetching via an HTTP client, updates to the RPC server to handle metadata and serve data shards, and comprehensive tests. The use of a dedicated NCCL worker thread in the RPC server to avoid concurrency issues is a particularly strong design choice. I've included a couple of suggestions for minor improvements related to efficiency and code clarity, but overall, this is an excellent contribution.
| """Merge two batches in metadata mode by modifying self in-place.""" | ||
| # Combine shards from both batches | ||
| all_shards = self.metadata.shards + other.metadata.shards | ||
| max_global_step = max(self.metadata.global_step, other.metadata.global_step) |
why take the maximum?
Same as #687 (comment).
We use the maximum to prevent garbage-collecting batches that are still in use.
| """Merge two batches in metadata mode by modifying self in-place.""" | ||
| # Combine shards from both batches | ||
| all_shards = self.metadata.shards + other.metadata.shards | ||
| max_global_step = max(self.metadata.global_step, other.metadata.global_step) |
Could you clarify under what circumstances we expect to merge batches with different steps?
The choice of max() implies a specific semantic: "the merged batch represents the latest step." But is this the intended behavior?
When rollout is called (train controller -> rollout controller -> infer engine), the returned global_step is the version of the inference engine.
When compute_logp is called (train controller -> train controller -> train engine), the returned global_step is the version of the train engine.
I understand that there is a natural difference between the two in async mode.
I added some debug logging in the trainer like this:
[screenshot: debug logging added in the trainer]
and the log is as follows when max_head_offpolicyness=2:
# grep "after " grpo-new-1.log | grep global_step
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=fbf47d65-ccc0-4d27-802f-3451c845dd00, num_shards=16, global_step=0, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=d632a5ce-470d-4743-8462-a1d897fd75ab, num_shards=20, global_step=0, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=754e8a47-5bc4-4a76-9f80-4d776aa2ae43, num_shards=4, global_step=0, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=0fc36e74-26b5-4ea6-90a2-e50845655aa4, num_shards=16, global_step=0, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=64cd5d46-d3cc-47ae-81b9-2c4d8e48bd9e, num_shards=20, global_step=1, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=0d528e4c-0b5c-4258-bd21-029b20b84015, num_shards=4, global_step=1, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=036fdfcb-7cb6-4437-9895-7db7b1742ebb, num_shards=16, global_step=0, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=6e254a5f-2578-47f8-bb2e-6eb75a2b5af5, num_shards=20, global_step=2, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=fbfadff4-16d5-4ea4-a332-0636093a7029, num_shards=4, global_step=2, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=b4742c7d-3363-485a-950f-ab523f1f36d8, num_shards=16, global_step=2, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=c0c31b6c-e086-4041-98c3-c6743db8e979, num_shards=20, global_step=3, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=775f0c26-406f-4d80-9095-29b48f28a757, num_shards=4, global_step=3, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=5073d258-c084-4f57-9503-128b386ab593, num_shards=16, global_step=3, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=da2f4daf-bb25-4fe5-a0f0-255f20c0855f, num_shards=20, global_step=4, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=867be65f-e27d-4a20-94d5-138ba6b1f043, num_shards=4, global_step=4, data_loaded=False>
---
after prepare_batch, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=6374ba00-adc3-4afa-b815-ee0bcd9cea70, num_shards=16, global_step=4, data_loaded=False>
after recompute_logprob, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=c51e1dd7-824a-40cf-9172-1fe90ab72a67, num_shards=20, global_step=5, data_loaded=False>
after compute_advantage, batch: DistributedBatchMemory<mode=REMOTE, size=64, batch_id=b90398e3-8f8d-4a62-b4f3-622ed2db34ae, num_shards=4, global_step=5, data_loaded=False>

Introduces an explicit `BatchMode` enum to `DistributedBatchMemory` to strictly distinguish between local data and remote metadata states, ensuring consistency during operations like chunking and concatenation. Renames RPC keyword arguments to remove leading underscores for better API clarity (e.g., `should_broadcast`, `result_key`).
Force-pushed from 9714c82 to 8205d01
Updates the string representation of `DistributedBatchMemory` to display "size" instead of "total_size" and adjusts unit tests to match. Simplifies docstrings for several internal methods by removing verbose parameter and return value descriptions.
- Add guards to `DistributedBatchMemory` to prevent item access or deletion while in remote mode or when the dataset is empty, ensuring data is fetched first.
- Update `__setitem__` to remove the implicit union merging behavior and strictly enforce string keys.
- Remove a redundant null check when filtering trajectories in the rollout controller.
…nnections
- Support direct merging of DistributedBatchMemory instances via assignment by invoking the union method.
- Introduce a connection limit to the BatchDataClient to manage concurrent HTTP connections more effectively.
- Refactor session management in BatchDataClient to use context managers and share TCP connectors, preventing resource exhaustion and improving cleanup.
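As a side note for reviewers: the session-sharing pattern that commit describes roughly corresponds to the sketch below, assuming aiohttp is the HTTP client; the class name, connection limit, and `/batch/<shard_id>` route here are illustrative, not the PR's actual API.

```python
from typing import Optional

import aiohttp


class BatchDataClient:
    """Sketch: share one TCPConnector and reuse a single ClientSession so
    concurrent shard fetches don't exhaust sockets or leak connections."""

    def __init__(self, max_connections: int = 32):
        self._connector = aiohttp.TCPConnector(limit=max_connections)
        self._session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self) -> "BatchDataClient":
        # One session for the client's lifetime; requests reuse pooled connections.
        self._session = aiohttp.ClientSession(connector=self._connector)
        return self

    async def __aexit__(self, *exc) -> None:
        if self._session is not None:
            # Closing the session also closes the shared connector it owns.
            await self._session.close()

    async def fetch_shard(self, node_url: str, shard_id: str) -> bytes:
        # Hypothetical endpoint; the real route is whatever the RPC server exposes.
        async with self._session.get(f"{node_url}/batch/{shard_id}") as resp:
            resp.raise_for_status()
            return await resp.read()
```

Using the client as an async context manager ties connector cleanup to a single exit point, which is what prevents the resource exhaustion mentioned in the commit message.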
class BatchMode(Enum):
    """Explicit mode enum for DistributedBatchMemory.

    Attributes
    ----------
    LOCAL : auto
        Data stored locally in memory
    REMOTE : auto
        Only metadata; data fetched on-demand via HTTP
    EMPTY : auto
        Neither data nor metadata present (invalid/empty state)
    """

    LOCAL = auto()  # Data stored locally in memory
    REMOTE = auto()  # Only metadata; data fetched on-demand
    EMPTY = auto()  # Neither present (invalid state)
Better to rename this to BatchStatus.
def union(self, other: "DistributedBatchMemory") -> "DistributedBatchMemory":
    """Merge another batch with this one"""
    merged_data = {k: v for k, v in self.dataset.items()}
def union(self, other: DistributedBatchMemory) -> DistributedBatchMemory:
I recommend implementing union as an out-of-place operation, i.e. it should return a new DistributedBatch; otherwise it should be called union_, following the naming convention of PyTorch.
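To make the suggestion concrete, here is a minimal sketch of the out-of-place vs. in-place split under an assumed metadata structure; the `_Meta` stand-in and its fields are illustrative, not the PR's real classes.

```python
from dataclasses import dataclass, field


@dataclass
class _Meta:
    # Minimal stand-in for the batch metadata discussed in this PR.
    shards: list = field(default_factory=list)
    global_step: int = 0


class DistributedBatchMemory:
    def __init__(self, metadata: _Meta):
        self.metadata = metadata

    def union(self, other: "DistributedBatchMemory") -> "DistributedBatchMemory":
        """Out-of-place: build and return a new batch; self is left untouched."""
        merged = _Meta(
            shards=self.metadata.shards + other.metadata.shards,
            global_step=max(self.metadata.global_step, other.metadata.global_step),
        )
        return DistributedBatchMemory(merged)

    def union_(self, other: "DistributedBatchMemory") -> "DistributedBatchMemory":
        """In-place variant, marked with a trailing underscore per PyTorch naming."""
        self.metadata.shards = self.metadata.shards + other.metadata.shards
        self.metadata.global_step = max(self.metadata.global_step, other.metadata.global_step)
        return self
```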
    return self.dispatcher.runner


# ==================== DISTRIBUTED BATCH RPC WRAPPERS ====================
def clear_batches(self, global_step: int):
Clearing batches by global step is not a good idea. Due to asynchronous training, the definition of training step is kind of vague. I recommend clearing by task ids.
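For illustration only, clearing by task IDs could look roughly like the sketch below; the storage dict and function name are hypothetical, not part of this PR.

```python
from collections import defaultdict

# Hypothetical per-node shard store: task_id -> {tensor name -> serialized bytes}
_batch_storage = defaultdict(dict)


def clear_batches_by_task(task_ids):
    """Drop all shards belonging to the given task IDs, independent of any
    notion of 'global step', which is ambiguous under asynchronous training."""
    removed = 0
    for task_id in task_ids:
        if task_id in _batch_storage:
            removed += len(_batch_storage.pop(task_id))
    return removed
```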
batch_size: int
offset: int = 0
I don't think we need batch size or offset. We can just store data individually according to prompt/input/task ID. Each item corresponds to an entry in the training dataset.
| """Metadata for a distributed batch sharded across multiple nodes.""" | ||
|
|
||
| batch_id: str | ||
| global_step: int |
We only require shard IDs for clearing up data. What else is this entry used for?
app = Flask(__name__)


def _init_nccl_worker():
Nice workaround!
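For readers outside this thread: the workaround funnels every NCCL operation onto a single dedicated thread through a queue, since NCCL communicators should not be driven concurrently from multiple Flask request threads. A minimal sketch under that assumption (names are illustrative, not the PR's code):

```python
import queue
import threading

_nccl_tasks: "queue.Queue[tuple]" = queue.Queue()


def _nccl_worker_loop():
    """Single thread that executes every NCCL op, preserving issue order."""
    while True:
        fn, args, done = _nccl_tasks.get()
        try:
            fn(*args)      # e.g. a broadcast or all-gather collective
        finally:
            done.set()     # unblock the request handler that enqueued it


def _init_nccl_worker():
    threading.Thread(target=_nccl_worker_loop, daemon=True).start()


def run_on_nccl_thread(fn, *args):
    """Called from Flask handlers: enqueue the op and wait for completion."""
    done = threading.Event()
    _nccl_tasks.put((fn, args, done))
    done.wait()
```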
data_bytes = request.get_data()
buffer = io.BytesIO(data_bytes)
data = pickle.load(buffer)
Again, can we avoid pickling and instead implement a customized serialization protocol in scheduler/rpc/serialization.py?
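For illustration, a pickle-free protocol for a dict of numpy arrays can be as simple as a JSON header (dtypes and shapes) followed by the raw buffers; the sketch below is only the idea, not the actual contents of scheduler/rpc/serialization.py.

```python
import json
import struct

import numpy as np


def dumps(tensors: dict) -> bytes:
    """Serialize {name: ndarray} as: [4-byte header length][JSON header][raw buffers]."""
    header = {name: {"dtype": str(a.dtype), "shape": list(a.shape)}
              for name, a in tensors.items()}
    header_bytes = json.dumps(header).encode("utf-8")
    payload = b"".join(np.ascontiguousarray(a).tobytes() for a in tensors.values())
    return struct.pack("!I", len(header_bytes)) + header_bytes + payload


def loads(blob: bytes) -> dict:
    (header_len,) = struct.unpack("!I", blob[:4])
    header = json.loads(blob[4:4 + header_len].decode("utf-8"))
    out, offset = {}, 4 + header_len
    for name, meta in header.items():
        dtype = np.dtype(meta["dtype"])
        nbytes = dtype.itemsize * int(np.prod(meta["shape"]))
        arr = np.frombuffer(blob[offset:offset + nbytes], dtype=dtype)
        out[name] = arr.reshape(meta["shape"])
        offset += nbytes
    return out
```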
_batch_storage[shard_id] = (global_step, data)
_batch_storage_stats[shard_id] = len(data_bytes)
Here shard_id should be task_id + key, e.g., as a dataclass, which can uniquely point to a tensor across the training procedure.
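A sketch of that key as a frozen dataclass, assuming task_id identifies the dataset entry and key names the tensor field (both names illustrative):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardKey:
    """Uniquely identifies one stored tensor: the dataset entry it belongs
    to (task_id) and the field of the batch it holds (key)."""
    task_id: str   # e.g. prompt/input/task ID from the training dataset
    key: str       # e.g. "input_ids", "logprobs", "rewards"


# frozen=True makes instances hashable, so they can serve as dict keys:
_batch_storage = {}
_batch_storage[ShardKey(task_id="task-0001", key="logprobs")] = b"..."
```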
# Import here to avoid circular dependency
if isinstance(value, DistributedBatchMemory):
    # Use __getstate__ to get serializable state
    state = value.__getstate__()
We can just copy-paste the __getstate__ and __setstate__ functions here.
Description
In single-controller mode, all data passes through the central controller, resulting in a single-point bottleneck. This PR implements DistributedBatchMemory: the control plane only passes metadata, and the actual data is fetched on-demand via HTTP from the distributed nodes, completely removing the bottleneck on the central side.
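For reviewers skimming the diff, a minimal sketch of the metadata-only / fetch-on-demand flow described above; ShardMeta, DistributedBatch, get_data, and the /batch/<shard_id> route are illustrative names, not the exact API in this PR.

```python
from dataclasses import dataclass, field
from typing import Optional

import requests


@dataclass
class ShardMeta:
    node_url: str   # worker node that holds the actual data
    shard_id: str   # identifier used to retrieve it


@dataclass
class DistributedBatch:
    """The control plane only passes this lightweight object around."""
    shards: list = field(default_factory=list)
    _data: Optional[dict] = None   # filled in lazily

    def get_data(self) -> dict:
        """Fetch shard payloads over HTTP only when the data is actually needed."""
        if self._data is None:
            self._data = {}
            for shard in self.shards:
                # Hypothetical endpoint exposed by each worker's RPC server.
                resp = requests.get(f"{shard.node_url}/batch/{shard.shard_id}")
                resp.raise_for_status()
                self._data[shard.shard_id] = resp.content
        return self._data
```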
Related Issue
Fixes #(issue)