Add logic to stream weights in EmbeddingKVDB #2930

Open
wants to merge 2 commits into
base: main
from

@chouxi chouxi commented Apr 30, 2025

Summary:
Gated by enable_raw_embedding_streaming.
Adds logic in EmbeddingKVDB to send the passed-in tensors to the TrainingParameterServerService thrift service.
The passed-in arguments are:

  • table_names to get the table FQN when streaming
  • table_offsets to get the global row id across TBEs.
  • table_sizes to get size of each table in TBE to infer which table a specific row belongs to.
  • ps_server_port is the port that runs the local TrainingParameterServerService to stream tensors to.
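
To illustrate how these arguments could fit together, here is a minimal Python sketch (not the code in this PR). It assumes that tables are laid out back to back inside one TBE, so cumulative sums of table_sizes give each table's row range, and that table_offsets[i] is the global row id of the first row of table i. The function names locate_row and resolve_row are hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple


def locate_row(local_row: int, table_sizes: List[int]) -> Tuple[int, int]:
    """Map a row index within one TBE to (table index, row index in that table).

    Assumption: tables are stored contiguously in the order of table_sizes.
    """
    # Exclusive end of each table's row range, e.g. sizes [3, 5, 2] -> [3, 8, 10].
    ends = list(accumulate(table_sizes))
    if not ends or not 0 <= local_row < ends[-1]:
        raise IndexError(f"row {local_row} is outside the TBE")
    # The first table whose end is strictly greater than local_row owns the row.
    table_idx = bisect_right(ends, local_row)
    table_start = ends[table_idx] - table_sizes[table_idx]
    return table_idx, local_row - table_start


def resolve_row(
    local_row: int,
    table_names: List[str],
    table_offsets: List[int],
    table_sizes: List[int],
) -> Tuple[str, int]:
    """Return (table FQN, global row id) for a row index within one TBE.

    Assumption: table_offsets[i] is the global id of table i's first row.
    """
    table_idx, row_in_table = locate_row(local_row, table_sizes)
    return table_names[table_idx], table_offsets[table_idx] + row_in_table
```

A binary search over the cumulative sizes keeps the lookup at O(log T) per row for T tables; the real implementation may resolve rows differently.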

It creates a new thread weights_stream_thread_ in EmbeddingKVDB to stream the weights out of trainers asynchronously.
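
The asynchronous hand-off can be pictured with a small producer/consumer sketch in Python. This is only an illustration of the pattern, not the C++ code in EmbeddingKVDB: the class name WeightStreamer, the send callback (standing in for the thrift client call), and the max_pending bound are all assumptions, and error handling for a failing send is omitted.

```python
import queue
import threading
from typing import Callable, List, Optional, Tuple

# One unit of work: global row ids and their embedding rows (hypothetical shape).
Batch = Tuple[List[int], List[List[float]]]


class WeightStreamer:
    """Streams weight batches on a background thread so callers do not wait on I/O."""

    def __init__(
        self,
        send: Callable[[Batch], None],  # stand-in for the RPC to the parameter server
        enabled: bool = True,  # mirrors the enable_raw_embedding_streaming gate
        max_pending: int = 64,
    ) -> None:
        self._send = send
        self._enabled = enabled
        self._queue: "queue.Queue[Optional[Batch]]" = queue.Queue(maxsize=max_pending)
        self._thread: Optional[threading.Thread] = None
        if enabled:
            self._thread = threading.Thread(target=self._run, daemon=True)
            self._thread.start()

    def stream(self, row_ids: List[int], weights: List[List[float]]) -> None:
        """Enqueue a copy of the batch; blocks only if max_pending batches are queued."""
        if not self._enabled:
            return
        self._queue.put((list(row_ids), [list(w) for w in weights]))

    def _run(self) -> None:
        # Drain the queue in FIFO order until the None sentinel arrives.
        while True:
            item = self._queue.get()
            if item is None:
                break
            self._send(item)

    def close(self) -> None:
        """Flush the batches already queued, then stop the background thread."""
        if self._thread is not None:
            self._queue.put(None)
            self._thread.join()
            self._thread = None
```

Copying the batch before enqueueing lets the caller keep mutating its buffers, and the bounded queue applies back-pressure if the receiver falls behind; whether the actual thread copies or bounds its work this way is not stated in the PR.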

Differential Revision: D73792631

chouxi added 2 commits April 30, 2025 16:01
…ytorch#2928)

Summary:
X-link: facebookresearch/FBGEMM#1138

X-link: pytorch/FBGEMM#4053


As titled, add this option all the way to gate the upcoming changes of raw embedding streaming in SSDTBE.

Differential Revision: D73691088
Summary:
Gated by `enable_raw_embedding_streaming`.
Adds logic in EmbeddingKVDB to send the passed-in tensors to the `TrainingParameterServerService` thrift service.
The passed-in arguments are:
- `table_names` to get the table FQN when streaming
- `table_offsets` to get the global row id across TBEs.
- `table_sizes` to get size of each table in TBE to infer which table a specific row belongs to.
- `ps_server_port` is the port that runs the local `TrainingParameterServerService` to stream tensors to.

It creates a new thread `weights_stream_thread_` in EmbeddingKVDB to stream the weights out of trainers asynchronously.

Differential Revision: D73792631
@facebook-github-bot facebook-github-bot added the CLA Signed label Apr 30, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D73792631

chouxi added a commit to chouxi/FBGEMM that referenced this pull request Apr 30, 2025
Summary:
X-link: pytorch/torchrec#2930

Labels: CLA Signed, fb-exported
2 participants